This is an automated email from the ASF dual-hosted git repository.

liugddx pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/seatunnel.git


The following commit(s) were added to refs/heads/dev by this push:
     new 5b5ee84130 [Improve][Transform] Add LLM model provider microsoft 
(#7778)
5b5ee84130 is described below

commit 5b5ee8413052730cf5f361b701465b6f4fbfb219
Author: corgy-w <73771213+corg...@users.noreply.github.com>
AuthorDate: Tue Oct 8 09:35:23 2024 +0800

    [Improve][Transform] Add LLM model provider microsoft (#7778)
    
    Co-authored-by: Jia Fan <fanjiaemi...@qq.com>
---
 docs/en/transform-v2/llm.md                        |   7 +-
 docs/zh/transform-v2/llm.md                        |   4 +-
 .../apache/seatunnel/e2e/transform/TestLLMIT.java  |   7 ++
 .../test/resources/llm_microsoft_transform.conf    |  75 +++++++++++++++
 .../src/test/resources/mockserver-config.json      |  32 +++++++
 .../transform/nlpmodel/ModelProvider.java          |   1 +
 .../transform/nlpmodel/llm/LLMTransform.java       |  12 +++
 .../nlpmodel/llm/LLMTransformFactory.java          |   8 +-
 .../llm/remote/microsoft/MicrosoftModel.java       | 103 +++++++++++++++++++++
 .../transform/llm/LLMRequestJsonTest.java          |  34 +++++++
 10 files changed, 277 insertions(+), 6 deletions(-)

diff --git a/docs/en/transform-v2/llm.md b/docs/en/transform-v2/llm.md
index 8ee5a36a9a..81dc9b3c70 100644
--- a/docs/en/transform-v2/llm.md
+++ b/docs/en/transform-v2/llm.md
@@ -11,7 +11,7 @@ more.
 ## Options
 
 | name                   | type   | required | default value |
-|------------------------| ------ | -------- |---------------|
+|------------------------|--------|----------|---------------|
 | model_provider         | enum   | yes      |               |
 | output_data_type       | enum   | no       | String        |
 | output_column_name     | string | no       | llm_output    |
@@ -28,7 +28,9 @@ more.
 ### model_provider
 
 The model provider to use. The available options are:
-OPENAI, DOUBAO, KIMIAI, CUSTOM
+OPENAI, DOUBAO, KIMIAI, MICROSOFT, CUSTOM
+
+> tips: When using `MICROSOFT`, `api_path` must be set; this provider has no default endpoint.
 
 ### output_data_type
 
@@ -254,6 +256,7 @@ sink {
   }
 }
 ```
+
 ### Customize the LLM model
 
 ```hocon
diff --git a/docs/zh/transform-v2/llm.md b/docs/zh/transform-v2/llm.md
index c6f7aeefea..5ab37f5870 100644
--- a/docs/zh/transform-v2/llm.md
+++ b/docs/zh/transform-v2/llm.md
@@ -26,7 +26,9 @@
 ### model_provider
 
 要使用的模型提供者。可用选项为:
-OPENAI、DOUBAO、KIMIAI、CUSTOM
+OPENAI、DOUBAO、KIMIAI、MICROSOFT、CUSTOM
+
+> tips: 如果使用 `MICROSOFT`，必须配置 `api_path`（该提供者没有默认地址）。
 
 ### output_data_type
 
diff --git 
a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/java/org/apache/seatunnel/e2e/transform/TestLLMIT.java
 
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/java/org/apache/seatunnel/e2e/transform/TestLLMIT.java
index d98a5e7e33..f739e7af96 100644
--- 
a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/java/org/apache/seatunnel/e2e/transform/TestLLMIT.java
+++ 
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/java/org/apache/seatunnel/e2e/transform/TestLLMIT.java
@@ -88,6 +88,13 @@ public class TestLLMIT extends TestSuiteBase implements 
TestResource {
         Assertions.assertEquals(0, execResult.getExitCode());
     }
 
+    /** Runs the MICROSOFT-provider LLM job against the mock server and expects a clean exit. */
+    @TestTemplate
+    public void testLLMWithMicrosoft(TestContainer container)
+            throws IOException, InterruptedException {
+        final String jobConfig = "/llm_microsoft_transform.conf";
+        Container.ExecResult jobResult = container.executeJob(jobConfig);
+        Assertions.assertEquals(0, jobResult.getExitCode());
+    }
+
     @TestTemplate
     public void testLLMWithOpenAIBoolean(TestContainer container)
             throws IOException, InterruptedException {
diff --git 
a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/llm_microsoft_transform.conf
 
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/llm_microsoft_transform.conf
new file mode 100644
index 0000000000..37205a3aca
--- /dev/null
+++ 
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/llm_microsoft_transform.conf
@@ -0,0 +1,75 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+######
+###### This config file is a demonstration of batch processing in seatunnel config
+######
+
+env {
+  job.mode = "BATCH"
+}
+
+source {
+  FakeSource {
+    row.num = 5
+    schema = {
+      fields {
+        id = "int"
+        name = "string"
+      }
+    }
+    rows = [
+      {fields = [1, "Jia Fan"], kind = INSERT}
+      {fields = [2, "Hailin Wang"], kind = INSERT}
+      {fields = [3, "Tomas"], kind = INSERT}
+      {fields = [4, "Eric"], kind = INSERT}
+      {fields = [5, "Guangdong Liu"], kind = INSERT}
+    ]
+    result_table_name = "fake"
+  }
+}
+
+transform {
+  LLM {
+    source_table_name = "fake"
+    model_provider = MICROSOFT
+    model = gpt-35-turbo
+    api_key = sk-xxx
+    prompt = "Determine whether someone is Chinese or American by their name"
+    api_path = 
"http://mockserver:1080/openai/deployments/${model}/chat/completions?api-version=2024-02-01";
+    result_table_name = "llm_output"
+  }
+}
+
+sink {
+  Assert {
+    source_table_name = "llm_output"
+    rules =
+      {
+        field_rules = [
+          {
+            field_name = llm_output
+            field_type = string
+            field_value = [
+              {
+                rule_type = NOT_NULL
+              }
+            ]
+          }
+        ]
+      }
+  }
+}
\ No newline at end of file
diff --git 
a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/mockserver-config.json
 
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/mockserver-config.json
index 44dd94396e..ffdb409c9c 100644
--- 
a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/mockserver-config.json
+++ 
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/mockserver-config.json
@@ -104,5 +104,37 @@
         "Content-Type": "application/json"
       }
     }
+  },
+  {
+    "httpRequest": {
+      "method": "POST",
+      "path": "/openai/deployments/gpt-35-turbo/chat/.*"
+    },
+    "httpResponse": {
+      "body": {
+        "id": "chatcmpl-6v7mkQj980V1yBec6ETrKPRqFjNw9",
+        "object": "chat.completion",
+        "created": 1679072642,
+        "model": "gpt-35-turbo",
+        "usage": {
+          "prompt_tokens": 58,
+          "completion_tokens": 68,
+          "total_tokens": 126
+        },
+        "choices": [
+          {
+            "message": {
+              "role": "assistant",
+              "content": "[\"Chinese\"]"
+            },
+            "finish_reason": "stop",
+            "index": 0
+          }
+        ]
+      },
+      "headers": {
+        "Content-Type": "application/json"
+      }
+    }
   }
 ]
diff --git 
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelProvider.java
 
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelProvider.java
index ce22bc5a6d..3172137706 100644
--- 
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelProvider.java
+++ 
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelProvider.java
@@ -26,6 +26,7 @@ public enum ModelProvider {
             "https://ark.cn-beijing.volces.com/api/v3/embeddings";),
     QIANFAN("", 
"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings";),
     KIMIAI("https://api.moonshot.cn/v1/chat/completions";, ""),
+    MICROSOFT("", ""),
     CUSTOM("", ""),
     LOCAL("", "");
 
diff --git 
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransform.java
 
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransform.java
index 08ae42e443..069945951b 100644
--- 
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransform.java
+++ 
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransform.java
@@ -31,6 +31,7 @@ import 
org.apache.seatunnel.transform.nlpmodel.ModelTransformConfig;
 import org.apache.seatunnel.transform.nlpmodel.llm.remote.Model;
 import org.apache.seatunnel.transform.nlpmodel.llm.remote.custom.CustomModel;
 import org.apache.seatunnel.transform.nlpmodel.llm.remote.kimiai.KimiAIModel;
+import 
org.apache.seatunnel.transform.nlpmodel.llm.remote.microsoft.MicrosoftModel;
 import org.apache.seatunnel.transform.nlpmodel.llm.remote.openai.OpenAIModel;
 
 import lombok.NonNull;
@@ -94,6 +95,17 @@ public class LLMTransform extends SingleFieldOutputTransform 
{
                                         LLMTransformConfig.CustomRequestConfig
                                                 .CUSTOM_RESPONSE_PARSE));
                 break;
+            case MICROSOFT:
+                model =
+                        new MicrosoftModel(
+                                inputCatalogTable.getSeaTunnelRowType(),
+                                outputDataType.getSqlType(),
+                                
config.get(LLMTransformConfig.INFERENCE_COLUMNS),
+                                config.get(LLMTransformConfig.PROMPT),
+                                config.get(LLMTransformConfig.MODEL),
+                                config.get(LLMTransformConfig.API_KEY),
+                                
provider.usedLLMPath(config.get(LLMTransformConfig.API_PATH)));
+                break;
             case OPENAI:
             case DOUBAO:
                 model =
diff --git 
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransformFactory.java
 
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransformFactory.java
index eda57e1275..834c0b4d17 100644
--- 
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransformFactory.java
+++ 
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransformFactory.java
@@ -26,7 +26,6 @@ import org.apache.seatunnel.api.table.factory.Factory;
 import org.apache.seatunnel.api.table.factory.TableTransformFactory;
 import org.apache.seatunnel.api.table.factory.TableTransformFactoryContext;
 import org.apache.seatunnel.transform.nlpmodel.ModelProvider;
-import org.apache.seatunnel.transform.nlpmodel.ModelTransformConfig;
 
 import com.google.auto.service.AutoService;
 
@@ -50,14 +49,17 @@ public class LLMTransformFactory implements 
TableTransformFactory {
                         LLMTransformConfig.PROCESS_BATCH_SIZE)
                 .conditional(
                         LLMTransformConfig.MODEL_PROVIDER,
-                        Lists.newArrayList(ModelProvider.OPENAI, 
ModelProvider.DOUBAO),
+                        Lists.newArrayList(
+                                ModelProvider.OPENAI,
+                                ModelProvider.DOUBAO,
+                                ModelProvider.MICROSOFT),
                         LLMTransformConfig.API_KEY)
                 .conditional(
                         LLMTransformConfig.MODEL_PROVIDER,
                         ModelProvider.QIANFAN,
                         LLMTransformConfig.API_KEY,
                         LLMTransformConfig.SECRET_KEY,
-                        ModelTransformConfig.OAUTH_PATH)
+                        LLMTransformConfig.OAUTH_PATH)
                 .conditional(
                         LLMTransformConfig.MODEL_PROVIDER,
                         ModelProvider.CUSTOM,
diff --git 
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/microsoft/MicrosoftModel.java
 
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/microsoft/MicrosoftModel.java
new file mode 100644
index 0000000000..b6362c41a3
--- /dev/null
+++ 
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/microsoft/MicrosoftModel.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.seatunnel.transform.nlpmodel.llm.remote.microsoft;
+
+import 
org.apache.seatunnel.shade.com.fasterxml.jackson.core.type.TypeReference;
+import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.JsonNode;
+import 
org.apache.seatunnel.shade.com.fasterxml.jackson.databind.node.ArrayNode;
+import 
org.apache.seatunnel.shade.com.fasterxml.jackson.databind.node.ObjectNode;
+
+import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
+import org.apache.seatunnel.api.table.type.SqlType;
+import org.apache.seatunnel.transform.nlpmodel.CustomConfigPlaceholder;
+import org.apache.seatunnel.transform.nlpmodel.llm.remote.AbstractModel;
+
+import org.apache.http.client.config.RequestConfig;
+import org.apache.http.client.methods.CloseableHttpResponse;
+import org.apache.http.client.methods.HttpPost;
+import org.apache.http.entity.StringEntity;
+import org.apache.http.impl.client.CloseableHttpClient;
+import org.apache.http.impl.client.HttpClients;
+import org.apache.http.util.EntityUtils;
+
+import com.google.common.annotations.VisibleForTesting;
+
+import java.io.IOException;
+import java.util.List;
+
+/**
+ * LLM model for the {@code MICROSOFT} model provider.
+ *
+ * <p>Each call POSTs an OpenAI-style {@code messages} payload (system prompt followed by the
+ * user data) to the configured {@code api_path}. It reads {@code choices[0].message.content}
+ * from the response, which is parsed (after {@code convertData}) as a JSON list of strings.
+ */
+public class MicrosoftModel extends AbstractModel {
+
+    private final CloseableHttpClient client;
+    private final String apiKey;
+    private final String model;
+    private final String apiPath;
+
+    /**
+     * Creates a model bound to a single chat-completions endpoint.
+     *
+     * @param rowType row type of the incoming data, forwarded to {@link AbstractModel}
+     * @param outputType SQL type of the produced output, forwarded to {@link AbstractModel}
+     * @param projectionColumns inference columns, forwarded to {@link AbstractModel}
+     * @param prompt prompt forwarded to {@link AbstractModel}
+     * @param model model name; substituted for the {@code ${model}} placeholder in {@code
+     *     apiPath}
+     * @param apiKey credential sent with every request
+     * @param apiPath endpoint URL; the MICROSOFT provider has no default, so it must be
+     *     configured
+     */
+    public MicrosoftModel(
+            SeaTunnelRowType rowType,
+            SqlType outputType,
+            List<String> projectionColumns,
+            String prompt,
+            String model,
+            String apiKey,
+            String apiPath) {
+        super(rowType, outputType, projectionColumns, prompt);
+        this.model = model;
+        this.apiKey = apiKey;
+        this.apiPath =
+                CustomConfigPlaceholder.replacePlaceholders(
+                        apiPath, CustomConfigPlaceholder.REPLACE_PLACEHOLDER_MODEL, model, null);
+        this.client = HttpClients.createDefault();
+    }
+
+    /**
+     * Sends one chat request and returns the model's answers.
+     *
+     * @param prompt system prompt for this request
+     * @param data user content for this request
+     * @return the strings decoded from {@code choices[0].message.content}
+     * @throws IOException on transport errors, a non-200 status, or a response without
+     *     {@code choices[0].message.content}
+     */
+    @Override
+    protected List<String> chatWithModel(String prompt, String data) throws IOException {
+        HttpPost post = new HttpPost(apiPath);
+        // NOTE(review): Azure OpenAI key authentication normally expects an "api-key" header;
+        // "Authorization: Bearer" is meant for Microsoft Entra ID tokens. Confirm which kind of
+        // credential api_key is expected to carry.
+        post.setHeader("Authorization", "Bearer " + apiKey);
+        post.setHeader("Content-Type", "application/json");
+        ObjectNode objectNode = createJsonNodeFromData(prompt, data);
+        post.setEntity(new StringEntity(OBJECT_MAPPER.writeValueAsString(objectNode), "UTF-8"));
+        post.setConfig(
+                RequestConfig.custom().setConnectTimeout(20000).setSocketTimeout(20000).build());
+
+        String responseStr;
+        // try-with-resources releases the response (and its connection) on every path,
+        // including when a non-200 status makes us throw.
+        try (CloseableHttpResponse response = client.execute(post)) {
+            // Use UTF-8 when the response declares no charset, so non-Latin output survives.
+            responseStr =
+                    response.getEntity() == null
+                            ? ""
+                            : EntityUtils.toString(response.getEntity(), "UTF-8");
+            if (response.getStatusLine().getStatusCode() != 200) {
+                throw new IOException("Failed to chat with model, response: " + responseStr);
+            }
+        }
+
+        JsonNode root = OBJECT_MAPPER.readTree(responseStr);
+        JsonNode content =
+                root == null
+                        ? null
+                        : root.path("choices").path(0).path("message").path("content");
+        if (content == null || content.isMissingNode() || content.isNull()) {
+            throw new IOException(
+                    "Unexpected chat completion response, missing choices[0].message.content: "
+                            + responseStr);
+        }
+        return OBJECT_MAPPER.readValue(
+                convertData(content.asText()), new TypeReference<List<String>>() {});
+    }
+
+    /**
+     * Builds the request body: a {@code messages} array holding the prompt as the {@code
+     * system} message and the data as the {@code user} message.
+     */
+    @VisibleForTesting
+    public ObjectNode createJsonNodeFromData(String prompt, String data) {
+        ObjectNode objectNode = OBJECT_MAPPER.createObjectNode();
+        ArrayNode messages = objectNode.putArray("messages");
+        messages.addObject().put("role", "system").put("content", prompt);
+        messages.addObject().put("role", "user").put("content", data);
+        return objectNode;
+    }
+
+    /** Closes the underlying HTTP client. */
+    @Override
+    public void close() throws IOException {
+        if (client != null) {
+            client.close();
+        }
+    }
+}
diff --git 
a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/llm/LLMRequestJsonTest.java
 
b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/llm/LLMRequestJsonTest.java
index 91666c4139..870af980fe 100644
--- 
a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/llm/LLMRequestJsonTest.java
+++ 
b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/llm/LLMRequestJsonTest.java
@@ -28,6 +28,7 @@ import org.apache.seatunnel.api.table.type.SqlType;
 import org.apache.seatunnel.format.json.RowToJsonConverters;
 import org.apache.seatunnel.transform.nlpmodel.llm.remote.custom.CustomModel;
 import org.apache.seatunnel.transform.nlpmodel.llm.remote.kimiai.KimiAIModel;
+import 
org.apache.seatunnel.transform.nlpmodel.llm.remote.microsoft.MicrosoftModel;
 import org.apache.seatunnel.transform.nlpmodel.llm.remote.openai.OpenAIModel;
 
 import org.junit.jupiter.api.Assertions;
@@ -36,6 +37,7 @@ import org.junit.jupiter.api.Test;
 import com.google.common.collect.Lists;
 
 import java.io.IOException;
+import java.lang.reflect.Field;
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
@@ -130,6 +132,38 @@ public class LLMRequestJsonTest {
         model.close();
     }
 
+    /**
+     * Verifies that MicrosoftModel substitutes {@code ${model}} into the API path and builds
+     * the expected system/user {@code messages} request body.
+     */
+    @Test
+    void testMicrosoftRequestJson() throws Exception {
+        SeaTunnelRowType rowType =
+                new SeaTunnelRowType(
+                        new String[] {"id", "name"},
+                        new SeaTunnelDataType[] {BasicType.INT_TYPE, BasicType.STRING_TYPE});
+        MicrosoftModel model =
+                new MicrosoftModel(
+                        rowType,
+                        SqlType.STRING,
+                        null,
+                        "Determine whether someone is Chinese or American by their name",
+                        "gpt-35-turbo",
+                        "sk-xxx",
+                        "https://api.moonshot.cn/openai/deployments/${model}"
+                                + "/chat/completions?api-version=2024-02-01");
+        // Close the model (and its HTTP client) even when an assertion fails.
+        try {
+            // apiPath is private; read it reflectively to check that the ${model}
+            // placeholder was replaced with the model name.
+            Field apiPathField = model.getClass().getDeclaredField("apiPath");
+            apiPathField.setAccessible(true);
+            String apiPath = (String) apiPathField.get(model);
+            Assertions.assertEquals(
+                    "https://api.moonshot.cn/openai/deployments/gpt-35-turbo"
+                            + "/chat/completions?api-version=2024-02-01",
+                    apiPath);
+
+            ObjectNode node =
+                    model.createJsonNodeFromData(
+                            "Determine whether someone is Chinese or American by their name",
+                            "{\"id\":1, \"name\":\"John\"}");
+            Assertions.assertEquals(
+                    "{\"messages\":[{\"role\":\"system\",\"content\":\"Determine whether"
+                            + " someone is Chinese or American by their name\"},"
+                            + "{\"role\":\"user\",\"content\":\"{\\\"id\\\":1,"
+                            + " \\\"name\\\":\\\"John\\\"}\"}]}",
+                    OBJECT_MAPPER.writeValueAsString(node));
+        } finally {
+            model.close();
+        }
+    }
+
     @Test
     void testCustomRequestJson() throws IOException {
         SeaTunnelRowType rowType =

Reply via email to