This is an automated email from the ASF dual-hosted git repository. wanghailin pushed a commit to branch dev in repository https://gitbox.apache.org/repos/asf/seatunnel.git
The following commit(s) were added to refs/heads/dev by this push: new 51ffc5a97e [Feature][Transform-v2] Add support for Zhipu AI in Embedding and LLM module (#8790) 51ffc5a97e is described below commit 51ffc5a97e4f7a75d1df47ecba9039a52d59abae Author: xiaochen <598457...@qq.com> AuthorDate: Tue Feb 25 10:47:29 2025 +0800 [Feature][Transform-v2] Add support for Zhipu AI in Embedding and LLM module (#8790) --- docs/en/transform-v2/embedding.md | 29 ++++++----- docs/en/transform-v2/llm.md | 2 +- docs/zh/transform-v2/embedding.md | 29 ++++++----- docs/zh/transform-v2/llm.md | 2 +- seatunnel-transforms-v2/pom.xml | 4 +- .../transform/nlpmodel/ModelProvider.java | 3 ++ .../transform/nlpmodel/ModelTransformConfig.java | 3 ++ .../nlpmodel/embedding/EmbeddingTransform.java | 13 +++++ .../embedding/EmbeddingTransformFactory.java | 4 ++ .../nlpmodel/embedding/remote/AbstractModel.java | 14 ++--- .../embedding/remote/custom/CustomModel.java | 6 +-- .../embedding/remote/doubao/DoubaoModel.java | 10 ++-- .../embedding/remote/openai/OpenAIModel.java | 10 ++-- .../embedding/remote/qianfan/QianfanModel.java | 10 ++-- .../DoubaoModel.java => zhipu/ZhipuModel.java} | 60 ++++++++++++++-------- .../transform/nlpmodel/llm/LLMTransform.java | 1 + .../embedding/EmbeddingRequestJsonTest.java | 25 ++++++++- 17 files changed, 147 insertions(+), 78 deletions(-) diff --git a/docs/en/transform-v2/embedding.md b/docs/en/transform-v2/embedding.md index 350a23fc55..cbebd535b1 100644 --- a/docs/en/transform-v2/embedding.md +++ b/docs/en/transform-v2/embedding.md @@ -10,20 +10,21 @@ different API endpoints. ## Options -| Name | Type | Required | Default Value | Description | -|--------------------------------|--------|----------|---------------|-------------------------------------------------------------------------------------------------------------| -| model_provider | enum | yes | - | The model provider for embedding. Options may include `QIANFAN`, `OPENAI`, etc. 
| -| api_key | string | yes | - | The API key required to authenticate with the embedding service. | -| secret_key | string | yes | - | The secret key required for additional authentication with the embedding service. | -| single_vectorized_input_number | int | no | 1 | The number of inputs vectorized in one request. Default is 1. | -| vectorization_fields | map | yes | - | A mapping between input fields and their corresponding output vector fields. | -| model | string | yes | - | The specific model to use for embedding (e.g: `text-embedding-3-small` for OPENAI). | -| api_path | string | no | - | The API endpoint for the embedding service. Typically provided by the model provider. | -| oauth_path | string | no | - | The API endpoint for the oauth service. | -| custom_config | map | no | | Custom configurations for the model. | -| custom_response_parse | string | no | | Specifies how to parse the response from the model using JsonPath. Example: `$.choices[*].message.content`. | -| custom_request_headers | map | no | | Custom headers for the request to the model. | -| custom_request_body | map | no | | Custom body for the request. Supports placeholders like `${model}`, `${input}`. | +| Name | Type | Required | Default Value | Description | +|----------------------------------|--------|----------|---------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| model_provider | enum | yes | - | The model provider for embedding. Options may include `QIANFAN`, `OPENAI`, etc. | +| api_key | string | yes | - | The API key required to authenticate with the embedding service. | +| secret_key | string | yes | - | The secret key required for additional authentication with the embedding service. | +| single_vectorized_input_number | int | no | 1 | The number of inputs vectorized in one request. Default is 1. 
| +| vectorization_fields | map | yes | - | A mapping between input fields and their corresponding output vector fields. | +| model | string | yes | - | The specific model to use for embedding (e.g. `text-embedding-3-small` for OPENAI). | +| api_path | string | no | - | The API endpoint for the embedding service. Typically provided by the model provider. | +| dimension | int | no | 2048 | The vector dimension of the generated embeddings; defaults to 2048. Only used when `model_provider` is `ZHIPU`. The Embedding-3 model supports custom vector dimensions; recommended values are 256, 512, 1024, or 2048. | +| oauth_path | string | no | - | The API endpoint for the oauth service. | +| custom_config | map | no | | Custom configurations for the model. | +| custom_response_parse | string | no | | Specifies how to parse the response from the model using JsonPath. Example: `$.choices[*].message.content`. | +| custom_request_headers | map | no | | Custom headers for the request to the model. | +| custom_request_body | map | no | | Custom body for the request. Supports placeholders like `${model}`, `${input}`. | ### model_provider diff --git a/docs/en/transform-v2/llm.md b/docs/en/transform-v2/llm.md index 680121cb4d..0bc137bded 100644 --- a/docs/en/transform-v2/llm.md +++ b/docs/en/transform-v2/llm.md @@ -28,7 +28,7 @@ more. ### model_provider The model provider to use.
The available options are: -OPENAI, DOUBAO, DEEPSEEK, KIMIAI, MICROSOFT, CUSTOM +OPENAI, DOUBAO, DEEPSEEK, KIMIAI, MICROSOFT, ZHIPU, CUSTOM > tips: If you use Microsoft, please make sure api_path cannot be empty diff --git a/docs/zh/transform-v2/embedding.md b/docs/zh/transform-v2/embedding.md index e05c9c2442..8ea2b68f70 100644 --- a/docs/zh/transform-v2/embedding.md +++ b/docs/zh/transform-v2/embedding.md @@ -8,20 +8,21 @@ ## 配置选项 -| 名称 | 类型 | 是否必填 | 默认值 | 描述 | -|--------------------------------|--------|------|-----|------------------------------------------------------------------| -| model_provider | enum | 是 | - | embedding模型的提供商。可选项包括 `QIANFAN`、`OPENAI` 等。 | -| api_key | string | 是 | - | 用于验证embedding服务的API密钥。 | -| secret_key | string | 是 | - | 用于额外验证的密钥。一些提供商可能需要此密钥进行安全的API请求。 | -| single_vectorized_input_number | int | 否 | 1 | 单次请求向量化的输入数量。默认值为1。 | -| vectorization_fields | map | 是 | - | 输入字段和相应的输出向量字段之间的映射。 | -| model | string | 是 | - | 要使用的具体embedding模型。例如,如果提供商为OPENAI,可以指定 `text-embedding-3-small`。 | -| api_path | string | 否 | - | embedding服务的API。通常由模型提供商提供。 | -| oauth_path | string | 否 | - | oauth 服务的 API 。 | -| custom_config | map | 否 | | 模型的自定义配置。 | -| custom_response_parse | string | 否 | | 使用 JsonPath 解析模型响应的方式。示例:`$.choices[*].message.content`。 | -| custom_request_headers | map | 否 | | 发送到模型的请求的自定义头信息。 | -| custom_request_body | map | 否 | | 请求体的自定义配置。支持占位符如 `${model}`、`${input}`。 | +| 名称 | 类型 | 是否必填 | 默认值 | 描述 | +|----------------------------------|--------|------|--------|--------------------------------------------------------------------| +| model_provider | enum | 是 | - | embedding模型的提供商。可选项包括 `QIANFAN`、`OPENAI` 等。 | +| api_key | string | 是 | - | 用于验证embedding服务的API密钥。 | +| secret_key | string | 是 | - | 用于额外验证的密钥。一些提供商可能需要此密钥进行安全的API请求。 | +| single_vectorized_input_number | int | 否 | 1 | 单次请求向量化的输入数量。默认值为1。 | +| vectorization_fields | map | 是 | - | 输入字段和相应的输出向量字段之间的映射。 | +| model | string | 是 | - | 要使用的具体embedding模型。例如,如果提供商为OPENAI,可以指定 
`text-embedding-3-small`。 | +| api_path | string | 否 | - | embedding服务的API。通常由模型提供商提供。 | +| dimension | int | 否 | 2048 | 向量维度默认为 2048,Embedding-3模型支持自定义向量维度,建议选择256、512、1024或2048维度。 | +| oauth_path | string | 否 | - | oauth 服务的 API 。 | +| custom_config | map | 否 | | 模型的自定义配置。 | +| custom_response_parse | string | 否 | | 使用 JsonPath 解析模型响应的方式。示例:`$.choices[*].message.content`。 | +| custom_request_headers | map | 否 | | 发送到模型的请求的自定义头信息。 | +| custom_request_body | map | 否 | | 请求体的自定义配置。支持占位符如 `${model}`、`${input}`。 | ### embedding_model_provider diff --git a/docs/zh/transform-v2/llm.md b/docs/zh/transform-v2/llm.md index c1d05d59a3..c6cead3dfd 100644 --- a/docs/zh/transform-v2/llm.md +++ b/docs/zh/transform-v2/llm.md @@ -26,7 +26,7 @@ ### model_provider 要使用的模型提供者。可用选项为: -OPENAI,DOUBAO,DEEPSEEK,KIMIAI,MICROSOFT, CUSTOM +OPENAI,DOUBAO,DEEPSEEK,KIMIAI,MICROSOFT, ZHIPU, CUSTOM > tips: 如果使用 Microsoft, 请确保 api_path 配置不能为空 diff --git a/seatunnel-transforms-v2/pom.xml b/seatunnel-transforms-v2/pom.xml index f15c1aae40..5f74ad156b 100644 --- a/seatunnel-transforms-v2/pom.xml +++ b/seatunnel-transforms-v2/pom.xml @@ -32,6 +32,8 @@ <properties> <httpclient.version>4.5.13</httpclient.version> <httpcore.version>4.4.4</httpcore.version> + <mockwebserver.version>3.6.0</mockwebserver.version> + <zhipu.version>release-V4-2.3.0</zhipu.version> </properties> <dependencyManagement> @@ -95,7 +97,7 @@ <dependency> <groupId>com.squareup.okhttp3</groupId> <artifactId>mockwebserver</artifactId> - <version>3.6.0</version> + <version>${mockwebserver.version}</version> <scope>test</scope> </dependency> </dependencies> diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelProvider.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelProvider.java index f18ffdfc8e..aaeaee90ad 100644 --- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelProvider.java +++ 
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelProvider.java @@ -28,6 +28,9 @@ public enum ModelProvider { KIMIAI("https://api.moonshot.cn/v1/chat/completions", ""), DEEPSEEK("https://api.deepseek.com/chat/completions", ""), MICROSOFT("", ""), + ZHIPU( + "https://open.bigmodel.cn/api/paas/v4/chat/completions", + "https://open.bigmodel.cn/api/paas/v4/embeddings"), CUSTOM("", ""), LOCAL("", ""); diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelTransformConfig.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelTransformConfig.java index b123459750..c3709c70b6 100644 --- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelTransformConfig.java +++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelTransformConfig.java @@ -79,6 +79,9 @@ public class ModelTransformConfig implements Serializable { .withFallbackKeys("inference_batch_size") .withDescription("The row batch size of each process"); + public static final Option<Integer> DIMENSION = + Options.key("dimension").intType().defaultValue(2048).withDescription("dimension"); + public static class CustomRequestConfig { // Custom response parsing diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransform.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransform.java index c699c6bfe8..6a8729a198 100644 --- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransform.java +++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransform.java @@ -33,6 +33,7 @@ import org.apache.seatunnel.transform.nlpmodel.embedding.remote.custom.CustomMod import org.apache.seatunnel.transform.nlpmodel.embedding.remote.doubao.DoubaoModel; import 
org.apache.seatunnel.transform.nlpmodel.embedding.remote.openai.OpenAIModel; import org.apache.seatunnel.transform.nlpmodel.embedding.remote.qianfan.QianfanModel; +import org.apache.seatunnel.transform.nlpmodel.embedding.remote.zhipu.ZhipuModel; import org.apache.seatunnel.transform.nlpmodel.llm.LLMTransformConfig; import lombok.NonNull; @@ -136,6 +137,18 @@ public class EmbeddingTransform extends MultipleFieldOutputTransform { EmbeddingTransformConfig .SINGLE_VECTORIZED_INPUT_NUMBER)); break; + case ZHIPU: + model = + new ZhipuModel( + config.get(ModelTransformConfig.API_KEY), + config.get(ModelTransformConfig.MODEL), + provider.usedEmbeddingPath( + config.get(ModelTransformConfig.API_PATH)), + config.get(ModelTransformConfig.DIMENSION), + config.get( + EmbeddingTransformConfig + .SINGLE_VECTORIZED_INPUT_NUMBER)); + break; case LOCAL: default: throw new IllegalArgumentException("Unsupported model provider: " + provider); diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransformFactory.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransformFactory.java index 56e252e1bb..5f8e397e69 100644 --- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransformFactory.java +++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransformFactory.java @@ -62,6 +62,10 @@ public class EmbeddingTransformFactory implements TableTransformFactory { LLMTransformConfig.MODEL_PROVIDER, ModelProvider.CUSTOM, LLMTransformConfig.CustomRequestConfig.CUSTOM_CONFIG) + .conditional( + EmbeddingTransformConfig.MODEL_PROVIDER, + ModelProvider.ZHIPU, + EmbeddingTransformConfig.DIMENSION) .optional(TransformCommonOptions.MULTI_TABLES) .optional(TransformCommonOptions.TABLE_MATCH_REGEX) .build(); diff --git 
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/AbstractModel.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/AbstractModel.java index 0803dfd7ad..a53fa4684c 100644 --- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/AbstractModel.java +++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/AbstractModel.java @@ -42,23 +42,23 @@ public abstract class AbstractModel implements Model { public List<ByteBuffer> vectorization(Object[] fields) throws IOException { List<ByteBuffer> result = new ArrayList<>(); - List<List<Float>> vectors = batchProcess(fields, singleVectorizedInputNumber); - for (List<Float> vector : vectors) { - result.add(BufferUtils.toByteBuffer(vector.toArray(new Float[0]))); + List<List<Double>> vectors = batchProcess(fields, singleVectorizedInputNumber); + for (List<Double> vector : vectors) { + result.add(BufferUtils.toByteBuffer(vector.toArray(new Double[0]))); } return result; } - protected abstract List<List<Float>> vector(Object[] fields) throws IOException; + protected abstract List<List<Double>> vector(Object[] fields) throws IOException; - public List<List<Float>> batchProcess(Object[] array, int batchSize) throws IOException { - List<List<Float>> merged = new ArrayList<>(); + public List<List<Double>> batchProcess(Object[] array, int batchSize) throws IOException { + List<List<Double>> merged = new ArrayList<>(); if (array == null || array.length == 0) { return merged; } for (int i = 0; i < array.length; i += batchSize) { Object[] batch = ArrayUtils.subarray(array, i, i + batchSize); - List<List<Float>> vector = vector(batch); + List<List<Double>> vector = vector(batch); merged.addAll(vector); } if (array.length != merged.size()) { diff --git 
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/custom/CustomModel.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/custom/CustomModel.java index 179315f956..ea39f15462 100644 --- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/custom/CustomModel.java +++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/custom/CustomModel.java @@ -67,7 +67,7 @@ public class CustomModel extends AbstractModel { } @Override - protected List<List<Float>> vector(Object[] fields) throws IOException { + protected List<List<Double>> vector(Object[] fields) throws IOException { return vectorGeneration(fields); } @@ -76,7 +76,7 @@ public class CustomModel extends AbstractModel { return vectorGeneration(new Object[] {DIMENSION_EXAMPLE}).size(); } - private List<List<Float>> vectorGeneration(Object[] fields) throws IOException { + private List<List<Double>> vectorGeneration(Object[] fields) throws IOException { HttpPost post = new HttpPost(apiPath); // Construct a request with custom parameters for (Map.Entry<String, String> entry : header.entrySet()) { @@ -96,7 +96,7 @@ public class CustomModel extends AbstractModel { } return OBJECT_MAPPER.convertValue( - parseResponse(responseStr), new TypeReference<List<List<Float>>>() {}); + parseResponse(responseStr), new TypeReference<List<List<Double>>>() {}); } @VisibleForTesting diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/doubao/DoubaoModel.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/doubao/DoubaoModel.java index f2b1e348c7..1591cd587e 100644 --- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/doubao/DoubaoModel.java +++ 
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/doubao/DoubaoModel.java @@ -54,7 +54,7 @@ public class DoubaoModel extends AbstractModel { } @Override - protected List<List<Float>> vector(Object[] fields) throws IOException { + protected List<List<Double>> vector(Object[] fields) throws IOException { return vectorGeneration(fields); } @@ -63,7 +63,7 @@ public class DoubaoModel extends AbstractModel { return vectorGeneration(new Object[] {DIMENSION_EXAMPLE}).size(); } - private List<List<Float>> vectorGeneration(Object[] fields) throws IOException { + private List<List<Double>> vectorGeneration(Object[] fields) throws IOException { HttpPost post = new HttpPost(apiPath); post.setHeader("Authorization", "Bearer " + apiKey); post.setHeader("Content-Type", "application/json"); @@ -82,14 +82,14 @@ public class DoubaoModel extends AbstractModel { } JsonNode data = OBJECT_MAPPER.readTree(responseStr).get("data"); - List<List<Float>> embeddings = new ArrayList<>(); + List<List<Double>> embeddings = new ArrayList<>(); if (data.isArray()) { for (JsonNode node : data) { JsonNode embeddingNode = node.get("embedding"); - List<Float> embedding = + List<Double> embedding = OBJECT_MAPPER.readValue( - embeddingNode.traverse(), new TypeReference<List<Float>>() {}); + embeddingNode.traverse(), new TypeReference<List<Double>>() {}); embeddings.add(embedding); } } diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/openai/OpenAIModel.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/openai/OpenAIModel.java index 2a45cc829f..467d6cb406 100644 --- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/openai/OpenAIModel.java +++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/openai/OpenAIModel.java @@ -53,7 +53,7 @@ public class OpenAIModel 
extends AbstractModel { } @Override - protected List<List<Float>> vector(Object[] fields) throws IOException { + protected List<List<Double>> vector(Object[] fields) throws IOException { if (fields.length > 1) { throw new IllegalArgumentException("OpenAI model only supports single input"); } @@ -65,7 +65,7 @@ public class OpenAIModel extends AbstractModel { return vectorGeneration(new Object[] {DIMENSION_EXAMPLE}).size(); } - private List<List<Float>> vectorGeneration(Object[] fields) throws IOException { + private List<List<Double>> vectorGeneration(Object[] fields) throws IOException { HttpPost post = new HttpPost(apiPath); post.setHeader("Authorization", "Bearer " + apiKey); post.setHeader("Content-Type", "application/json"); @@ -84,14 +84,14 @@ public class OpenAIModel extends AbstractModel { } JsonNode data = OBJECT_MAPPER.readTree(responseStr).get("data"); - List<List<Float>> embeddings = new ArrayList<>(); + List<List<Double>> embeddings = new ArrayList<>(); if (data.isArray()) { for (JsonNode node : data) { JsonNode embeddingNode = node.get("embedding"); - List<Float> embedding = + List<Double> embedding = OBJECT_MAPPER.readValue( - embeddingNode.traverse(), new TypeReference<List<Float>>() {}); + embeddingNode.traverse(), new TypeReference<List<Double>>() {}); embeddings.add(embedding); } } diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/qianfan/QianfanModel.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/qianfan/QianfanModel.java index f85619eb3e..67c1a8147a 100644 --- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/qianfan/QianfanModel.java +++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/qianfan/QianfanModel.java @@ -100,7 +100,7 @@ public class QianfanModel extends AbstractModel { } @Override - public List<List<Float>> vector(Object[] 
fields) throws IOException { + public List<List<Double>> vector(Object[] fields) throws IOException { return vectorGeneration(fields); } @@ -109,7 +109,7 @@ public class QianfanModel extends AbstractModel { return vectorGeneration(new Object[] {DIMENSION_EXAMPLE}).get(0).size(); } - private List<List<Float>> vectorGeneration(Object[] fields) throws IOException { + private List<List<Double>> vectorGeneration(Object[] fields) throws IOException { String formattedApiPath = String.format( (apiPath.endsWith("/") ? apiPath : apiPath + "/") + "%s?access_token=%s", @@ -143,14 +143,14 @@ public class QianfanModel extends AbstractModel { "Failed to get vector from qianfan, response: " + result.get("error_msg")); } - List<List<Float>> embeddings = new ArrayList<>(); + List<List<Double>> embeddings = new ArrayList<>(); JsonNode data = result.get("data"); if (data.isArray()) { for (JsonNode node : data) { - List<Float> embedding = + List<Double> embedding = OBJECT_MAPPER.readValue( node.get("embedding").traverse(), - new TypeReference<List<Float>>() {}); + new TypeReference<List<Double>>() {}); embeddings.add(embedding); } } diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/doubao/DoubaoModel.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/zhipu/ZhipuModel.java similarity index 68% copy from seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/doubao/DoubaoModel.java copy to seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/zhipu/ZhipuModel.java index f2b1e348c7..df72261bb5 100644 --- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/doubao/DoubaoModel.java +++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/zhipu/ZhipuModel.java @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.seatunnel.transform.nlpmodel.embedding.remote.doubao; +package org.apache.seatunnel.transform.nlpmodel.embedding.remote.zhipu; import org.apache.seatunnel.shade.com.fasterxml.jackson.core.type.TypeReference; import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.JsonNode; @@ -25,6 +25,7 @@ import org.apache.seatunnel.shade.com.google.common.annotations.VisibleForTestin import org.apache.seatunnel.transform.nlpmodel.embedding.remote.AbstractModel; +import org.apache.http.HttpHeaders; import org.apache.http.client.config.RequestConfig; import org.apache.http.client.methods.CloseableHttpResponse; import org.apache.http.client.methods.HttpPost; @@ -34,62 +35,77 @@ import org.apache.http.impl.client.HttpClients; import org.apache.http.util.EntityUtils; import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; import java.util.List; -public class DoubaoModel extends AbstractModel { +/** Zhipu model. 
Refer <a href="https://bigmodel.cn/dev/api/vector/embedding">embedding api </a> */ +public class ZhipuModel extends AbstractModel { private final CloseableHttpClient client; - private final String apiKey; private final String model; - private final String apiPath; - - public DoubaoModel(String apiKey, String model, String apiPath, Integer vectorizedNumber) { + private final String apiKey; + private final String apiPath;; + private final Integer dimension; + private final Integer MAX_INPUT_SIZE = 64; + + public ZhipuModel( + String apiKey, + String model, + String apiPath, + Integer dimension, + Integer vectorizedNumber) + throws IOException { super(vectorizedNumber); - this.apiKey = apiKey; this.model = model; + this.apiKey = apiKey; this.apiPath = apiPath; + this.dimension = dimension; this.client = HttpClients.createDefault(); } @Override - protected List<List<Float>> vector(Object[] fields) throws IOException { + public List<List<Double>> vector(Object[] fields) throws IOException { return vectorGeneration(fields); } @Override public Integer dimension() throws IOException { - return vectorGeneration(new Object[] {DIMENSION_EXAMPLE}).size(); + return dimension; } - private List<List<Float>> vectorGeneration(Object[] fields) throws IOException { + private List<List<Double>> vectorGeneration(Object[] fields) throws IOException { + + if (fields == null || fields.length > MAX_INPUT_SIZE) { + throw new IOException( + "Zhipu input text for vectorization, with a maximum limit of 64 entries."); + } HttpPost post = new HttpPost(apiPath); - post.setHeader("Authorization", "Bearer " + apiKey); - post.setHeader("Content-Type", "application/json"); + post.setHeader(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey); + post.setHeader(HttpHeaders.CONTENT_TYPE, "application/json"); post.setConfig( RequestConfig.custom().setConnectTimeout(20000).setSocketTimeout(20000).build()); post.setEntity( new StringEntity( - OBJECT_MAPPER.writeValueAsString(createJsonNodeFromData(fields)), 
"UTF-8")); + OBJECT_MAPPER.writeValueAsString(createJsonNodeFromData(fields)), + StandardCharsets.UTF_8.name())); CloseableHttpResponse response = client.execute(post); String responseStr = EntityUtils.toString(response.getEntity()); - if (response.getStatusLine().getStatusCode() != 200) { - throw new IOException("Failed to get vector from doubao, response: " + responseStr); + throw new IOException("Failed to get vector from zhipu, response: " + responseStr); } - JsonNode data = OBJECT_MAPPER.readTree(responseStr).get("data"); - List<List<Float>> embeddings = new ArrayList<>(); + List<List<Double>> embeddings = new ArrayList<>(); if (data.isArray()) { for (JsonNode node : data) { JsonNode embeddingNode = node.get("embedding"); - List<Float> embedding = + List<Double> embedding = OBJECT_MAPPER.readValue( - embeddingNode.traverse(), new TypeReference<List<Float>>() {}); + embeddingNode.traverse(), new TypeReference<List<Double>>() {}); embeddings.add(embedding); } } @@ -99,7 +115,11 @@ public class DoubaoModel extends AbstractModel { @VisibleForTesting public ObjectNode createJsonNodeFromData(Object[] fields) { ArrayNode arrayNode = OBJECT_MAPPER.valueToTree(Arrays.asList(fields)); - return OBJECT_MAPPER.createObjectNode().put("model", model).set("input", arrayNode); + return OBJECT_MAPPER + .createObjectNode() + .put("model", model) + .put("dimensions", dimension) + .set("input", arrayNode); } @Override diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransform.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransform.java index 048e1bffce..8160cdc647 100644 --- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransform.java +++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransform.java @@ -109,6 +109,7 @@ public class LLMTransform extends SingleFieldOutputTransform { case DEEPSEEK: case OPENAI: 
case DOUBAO: + case ZHIPU: model = new OpenAIModel( inputCatalogTable.getSeaTunnelRowType(), diff --git a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/EmbeddingRequestJsonTest.java b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/EmbeddingRequestJsonTest.java index dc43cdeb23..46c893d182 100644 --- a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/EmbeddingRequestJsonTest.java +++ b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/EmbeddingRequestJsonTest.java @@ -26,6 +26,7 @@ import org.apache.seatunnel.transform.nlpmodel.embedding.remote.custom.CustomMod import org.apache.seatunnel.transform.nlpmodel.embedding.remote.doubao.DoubaoModel; import org.apache.seatunnel.transform.nlpmodel.embedding.remote.openai.OpenAIModel; import org.apache.seatunnel.transform.nlpmodel.embedding.remote.qianfan.QianfanModel; +import org.apache.seatunnel.transform.nlpmodel.embedding.remote.zhipu.ZhipuModel; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; @@ -93,6 +94,26 @@ public class EmbeddingRequestJsonTest { model.close(); } + @Test + void testZhipuRequestJson() throws IOException { + ZhipuModel model = + new ZhipuModel( + "apikey", + "modelName", + "https://open.bigmodel.cn/api/paas/v4/embeddings", + 64, + 1); + ObjectNode node = + model.createJsonNodeFromData( + new Object[] { + "Determine whether someone is Chinese or American by their name" + }); + Assertions.assertEquals( + "{\"model\":\"modelName\",\"dimensions\":64,\"input\":[\"Determine whether someone is Chinese or American by their name\"]}", + OBJECT_MAPPER.writeValueAsString(node)); + model.close(); + } + @Test void testCustomRequestJson() throws IOException { Map<String, String> header = new HashMap<>(); @@ -131,11 +152,11 @@ public class EmbeddingRequestJsonTest { new HashMap<>(), "$.data[*].embedding", 1); - List<List<Float>> lists = + List<List<Double>> lists = 
OBJECT_MAPPER.convertValue( customModel.parseResponse( "{\"created\":1725001256,\"id\":\"02172500125677376580aba8475a41c550bbf05104842f0405ef5\",\"data\":[{\"embedding\":[-1.625,0.07958984375,-1.5703125,-3.03125,-1.4609375,3.46875,-0.73046875,-2.578125,-0.66796875,1.71875,0.361328125,2,5.125,2.25,4.6875,1.4921875,-0.77734375,-0.466796875,0.0439453125,-2.46875,3.59375,4.96875,2.34375,-5.34375,0.11083984375,-5.875,3.0625,4.09375,3.4375,0.2265625,9,-1.9296875,2.25,0.765625,3.671875,-2.484375,-1.171875,-1.6171875, [...] - new TypeReference<List<List<Float>>>() {}); + new TypeReference<List<List<Double>>>() {}); Assertions.assertEquals(2, lists.size()); Assertions.assertEquals(2560, lists.get(0).size()); }