This is an automated email from the ASF dual-hosted git repository.

liugddx pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/seatunnel.git
The following commit(s) were added to refs/heads/dev by this push:
     new 5b5ee84130 [Improve][Transform] Add LLM model provider microsoft (#7778)
5b5ee84130 is described below

commit 5b5ee8413052730cf5f361b701465b6f4fbfb219
Author: corgy-w <73771213+corg...@users.noreply.github.com>
AuthorDate: Tue Oct 8 09:35:23 2024 +0800

    [Improve][Transform] Add LLM model provider microsoft (#7778)

    Co-authored-by: Jia Fan <fanjiaemi...@qq.com>
---
 docs/en/transform-v2/llm.md                        |   7 +-
 docs/zh/transform-v2/llm.md                        |   4 +-
 .../apache/seatunnel/e2e/transform/TestLLMIT.java  |   7 ++
 .../test/resources/llm_microsoft_transform.conf    |  75 +++++++++++++++
 .../src/test/resources/mockserver-config.json      |  32 +++++++
 .../transform/nlpmodel/ModelProvider.java          |   1 +
 .../transform/nlpmodel/llm/LLMTransform.java       |  12 +++
 .../nlpmodel/llm/LLMTransformFactory.java          |   8 +-
 .../llm/remote/microsoft/MicrosoftModel.java       | 103 +++++++++++++++++++++
 .../transform/llm/LLMRequestJsonTest.java          |  34 +++++++
 10 files changed, 277 insertions(+), 6 deletions(-)

diff --git a/docs/en/transform-v2/llm.md b/docs/en/transform-v2/llm.md
index 8ee5a36a9a..81dc9b3c70 100644
--- a/docs/en/transform-v2/llm.md
+++ b/docs/en/transform-v2/llm.md
@@ -11,7 +11,7 @@ more.
 ## Options
 
 | name                   | type   | required | default value |
-|------------------------| ------ | -------- |---------------|
+|------------------------|--------|----------|---------------|
 | model_provider         | enum   | yes      |               |
 | output_data_type       | enum   | no       | String        |
 | output_column_name     | string | no       | llm_output    |
@@ -28,7 +28,9 @@ more.
 ### model_provider
 
 The model provider to use. The available options are:
-OPENAI, DOUBAO, KIMIAI, CUSTOM
+OPENAI, DOUBAO, KIMIAI, MICROSOFT, CUSTOM
+
+> Tip: If you use Microsoft, make sure api_path is not empty.
 
 ### output_data_type
 
@@ -254,6 +256,7 @@ sink {
   }
 }
 ```
+
 ### Customize the LLM model
 
 ```hocon
diff --git a/docs/zh/transform-v2/llm.md b/docs/zh/transform-v2/llm.md
index c6f7aeefea..5ab37f5870 100644
--- a/docs/zh/transform-v2/llm.md
+++ b/docs/zh/transform-v2/llm.md
@@ -26,7 +26,9 @@
 ### model_provider
 
 要使用的模型提供者。可用选项为:
-OPENAI、DOUBAO、KIMIAI、CUSTOM
+OPENAI、DOUBAO、KIMIAI、MICROSOFT、CUSTOM
+
+> tips: 如果使用 Microsoft, 请确保 api_path 配置不能为空
 
 ### output_data_type
 
diff --git a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/java/org/apache/seatunnel/e2e/transform/TestLLMIT.java b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/java/org/apache/seatunnel/e2e/transform/TestLLMIT.java
index d98a5e7e33..f739e7af96 100644
--- a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/java/org/apache/seatunnel/e2e/transform/TestLLMIT.java
+++ b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/java/org/apache/seatunnel/e2e/transform/TestLLMIT.java
@@ -88,6 +88,13 @@ public class TestLLMIT extends TestSuiteBase implements TestResource {
         Assertions.assertEquals(0, execResult.getExitCode());
     }
 
+    @TestTemplate
+    public void testLLMWithMicrosoft(TestContainer container)
+            throws IOException, InterruptedException {
+        Container.ExecResult execResult = container.executeJob("/llm_microsoft_transform.conf");
+        Assertions.assertEquals(0, execResult.getExitCode());
+    }
+
     @TestTemplate
     public void testLLMWithOpenAIBoolean(TestContainer container)
             throws IOException, InterruptedException {
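The new provider resolves a `${model}` placeholder in the configured `api_path` (see the `MicrosoftModel` constructor and the unit test later in this diff). As a rough illustration of that substitution — a plain-Java sketch using `String.replace`, not the project's `CustomConfigPlaceholder` helper — the effect is:

```java
// Illustrative sketch only: shows the effect of resolving the ${model} placeholder in
// api_path; the transform itself uses CustomConfigPlaceholder.replacePlaceholders.
public class ApiPathPlaceholderSketch {
    public static void main(String[] args) {
        String apiPath =
                "http://mockserver:1080/openai/deployments/${model}/chat/completions?api-version=2024-02-01";
        String model = "gpt-35-turbo";
        String resolved = apiPath.replace("${model}", model);
        // -> http://mockserver:1080/openai/deployments/gpt-35-turbo/chat/completions?api-version=2024-02-01
        System.out.println(resolved);
    }
}
```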
diff --git a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/llm_microsoft_transform.conf b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/llm_microsoft_transform.conf
new file mode 100644
index 0000000000..37205a3aca
--- /dev/null
+++ b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/llm_microsoft_transform.conf
@@ -0,0 +1,75 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+######
+###### This config file is a demonstration of batch processing in SeaTunnel config
+######
+
+env {
+  job.mode = "BATCH"
+}
+
+source {
+  FakeSource {
+    row.num = 5
+    schema = {
+      fields {
+        id = "int"
+        name = "string"
+      }
+    }
+    rows = [
+      {fields = [1, "Jia Fan"], kind = INSERT}
+      {fields = [2, "Hailin Wang"], kind = INSERT}
+      {fields = [3, "Tomas"], kind = INSERT}
+      {fields = [4, "Eric"], kind = INSERT}
+      {fields = [5, "Guangdong Liu"], kind = INSERT}
+    ]
+    result_table_name = "fake"
+  }
+}
+
+transform {
+  LLM {
+    source_table_name = "fake"
+    model_provider = MICROSOFT
+    model = gpt-35-turbo
+    api_key = sk-xxx
+    prompt = "Determine whether someone is Chinese or American by their name"
+    api_path = "http://mockserver:1080/openai/deployments/${model}/chat/completions?api-version=2024-02-01"
+    result_table_name = "llm_output"
+  }
+}
+
+sink {
+  Assert {
+    source_table_name = "llm_output"
+    rules =
+      {
+        field_rules = [
+          {
+            field_name = llm_output
+            field_type = string
+            field_value = [
+              {
+                rule_type = NOT_NULL
+              }
+            ]
+          }
+        ]
+      }
+  }
+}
\ No newline at end of file
diff --git a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/mockserver-config.json b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/mockserver-config.json
index 44dd94396e..ffdb409c9c 100644
--- a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/mockserver-config.json
+++ b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/mockserver-config.json
@@ -104,5 +104,37 @@
         "Content-Type": "application/json"
       }
     }
+  },
+  {
+    "httpRequest": {
+      "method": "POST",
+      "path": "/openai/deployments/gpt-35-turbo/chat/.*"
+    },
+    "httpResponse": {
+      "body": {
+        "id": "chatcmpl-6v7mkQj980V1yBec6ETrKPRqFjNw9",
+        "object": "chat.completion",
+        "created": 1679072642,
+        "model": "gpt-35-turbo",
+        "usage": {
+          "prompt_tokens": 58,
+          "completion_tokens": 68,
+          "total_tokens": 126
+        },
+        "choices": [
+          {
+            "message": {
+              "role": "assistant",
+              "content": "[\"Chinese\"]"
+            },
+            "finish_reason": "stop",
+            "index": 0
+          }
+        ]
+      },
+      "headers": {
+        "Content-Type": "application/json"
+      }
+    }
   }
 ]
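The MockServer stub above answers with an Azure-OpenAI-style chat completion whose payload sits in `choices[0].message.content`. A small self-contained Jackson sketch of that extraction (the same navigation `MicrosoftModel.chatWithModel` performs further down, here with a plain `ObjectMapper` instead of SeaTunnel's shaded one):

```java
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;

public class ChatResponseParseSketch {
    public static void main(String[] args) throws Exception {
        // Trimmed copy of the response body served by the MockServer expectation above.
        String responseStr =
                "{\"choices\":[{\"message\":{\"role\":\"assistant\",\"content\":\"[\\\"Chinese\\\"]\"},"
                        + "\"finish_reason\":\"stop\",\"index\":0}]}";
        ObjectMapper mapper = new ObjectMapper();
        JsonNode result = mapper.readTree(responseStr);
        // Same path the transform walks: choices[0].message.content
        String content = result.get("choices").get(0).get("message").get("content").asText();
        System.out.println(content); // ["Chinese"]
    }
}
```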
diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelProvider.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelProvider.java
index ce22bc5a6d..3172137706 100644
--- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelProvider.java
+++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelProvider.java
@@ -26,6 +26,7 @@ public enum ModelProvider {
             "https://ark.cn-beijing.volces.com/api/v3/embeddings"),
     QIANFAN("", "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings"),
     KIMIAI("https://api.moonshot.cn/v1/chat/completions", ""),
+    MICROSOFT("", ""),
     CUSTOM("", ""),
     LOCAL("", "");
 
diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransform.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransform.java
index 08ae42e443..069945951b 100644
--- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransform.java
+++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransform.java
@@ -31,6 +31,7 @@ import org.apache.seatunnel.transform.nlpmodel.ModelTransformConfig;
 import org.apache.seatunnel.transform.nlpmodel.llm.remote.Model;
 import org.apache.seatunnel.transform.nlpmodel.llm.remote.custom.CustomModel;
 import org.apache.seatunnel.transform.nlpmodel.llm.remote.kimiai.KimiAIModel;
+import org.apache.seatunnel.transform.nlpmodel.llm.remote.microsoft.MicrosoftModel;
 import org.apache.seatunnel.transform.nlpmodel.llm.remote.openai.OpenAIModel;
 
 import lombok.NonNull;
@@ -94,6 +95,17 @@ public class LLMTransform extends SingleFieldOutputTransform {
                                 LLMTransformConfig.CustomRequestConfig
                                         .CUSTOM_RESPONSE_PARSE));
                 break;
+            case MICROSOFT:
+                model =
+                        new MicrosoftModel(
+                                inputCatalogTable.getSeaTunnelRowType(),
+                                outputDataType.getSqlType(),
+                                config.get(LLMTransformConfig.INFERENCE_COLUMNS),
+                                config.get(LLMTransformConfig.PROMPT),
+                                config.get(LLMTransformConfig.MODEL),
+                                config.get(LLMTransformConfig.API_KEY),
+                                provider.usedLLMPath(config.get(LLMTransformConfig.API_PATH)));
+                break;
             case OPENAI:
             case DOUBAO:
                 model =
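In the MICROSOFT branch the endpoint handed to the model is `provider.usedLLMPath(config.get(API_PATH))`. That method is not part of this diff; a plausible reading, consistent with the docs tip above and with `MICROSOFT("", "")` carrying empty default URLs, is that it falls back to the enum's built-in LLM path only when no `api_path` is configured — which is why Microsoft needs an explicit `api_path`. A hypothetical sketch of such a fallback (names and behavior assumed, not taken from the repository):

```java
// Hypothetical sketch only: a "configured path, otherwise provider default" fallback.
// The real ModelProvider.usedLLMPath may be implemented differently.
public enum ProviderSketch {
    OPENAI("https://api.openai.com/v1/chat/completions"), // assumed default, for illustration
    MICROSOFT(""); // no built-in default, so api_path must be supplied

    private final String defaultLlmPath;

    ProviderSketch(String defaultLlmPath) {
        this.defaultLlmPath = defaultLlmPath;
    }

    public String usedLLMPath(String configuredApiPath) {
        // Prefer the user-configured api_path; fall back to the provider default.
        return (configuredApiPath == null || configuredApiPath.isEmpty())
                ? defaultLlmPath
                : configuredApiPath;
    }

    public static void main(String[] args) {
        System.out.println(OPENAI.usedLLMPath(""));                 // falls back to provider default
        System.out.println(MICROSOFT.usedLLMPath("http://mockserver:1080/openai/...")); // configured path
    }
}
```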
diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransformFactory.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransformFactory.java
index eda57e1275..834c0b4d17 100644
--- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransformFactory.java
+++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransformFactory.java
@@ -26,7 +26,6 @@ import org.apache.seatunnel.api.table.factory.Factory;
 import org.apache.seatunnel.api.table.factory.TableTransformFactory;
 import org.apache.seatunnel.api.table.factory.TableTransformFactoryContext;
 import org.apache.seatunnel.transform.nlpmodel.ModelProvider;
-import org.apache.seatunnel.transform.nlpmodel.ModelTransformConfig;
 
 import com.google.auto.service.AutoService;
 
@@ -50,14 +49,17 @@ public class LLMTransformFactory implements TableTransformFactory {
                                 LLMTransformConfig.PROCESS_BATCH_SIZE)
                         .conditional(
                                 LLMTransformConfig.MODEL_PROVIDER,
-                                Lists.newArrayList(ModelProvider.OPENAI, ModelProvider.DOUBAO),
+                                Lists.newArrayList(
+                                        ModelProvider.OPENAI,
+                                        ModelProvider.DOUBAO,
+                                        ModelProvider.MICROSOFT),
                                 LLMTransformConfig.API_KEY)
                         .conditional(
                                 LLMTransformConfig.MODEL_PROVIDER,
                                 ModelProvider.QIANFAN,
                                 LLMTransformConfig.API_KEY,
                                 LLMTransformConfig.SECRET_KEY,
-                                ModelTransformConfig.OAUTH_PATH)
+                                LLMTransformConfig.OAUTH_PATH)
                         .conditional(
                                 LLMTransformConfig.MODEL_PROVIDER,
                                 ModelProvider.CUSTOM,
diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/microsoft/MicrosoftModel.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/microsoft/MicrosoftModel.java
new file mode 100644
index 0000000000..b6362c41a3
--- /dev/null
+++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/microsoft/MicrosoftModel.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.seatunnel.transform.nlpmodel.llm.remote.microsoft;
+
+import org.apache.seatunnel.shade.com.fasterxml.jackson.core.type.TypeReference;
+import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.JsonNode;
+import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.node.ArrayNode;
+import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.node.ObjectNode;
+
+import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
+import org.apache.seatunnel.api.table.type.SqlType;
+import org.apache.seatunnel.transform.nlpmodel.CustomConfigPlaceholder;
+import org.apache.seatunnel.transform.nlpmodel.llm.remote.AbstractModel;
+
+import org.apache.http.client.config.RequestConfig;
+import org.apache.http.client.methods.CloseableHttpResponse;
+import org.apache.http.client.methods.HttpPost;
+import org.apache.http.entity.StringEntity;
+import org.apache.http.impl.client.CloseableHttpClient;
+import org.apache.http.impl.client.HttpClients;
+import org.apache.http.util.EntityUtils;
+
+import com.google.common.annotations.VisibleForTesting;
+
+import java.io.IOException;
+import java.util.List;
+
+public class MicrosoftModel extends AbstractModel {
+
+    private final CloseableHttpClient client;
+    private final String apiKey;
+    private final String model;
+    private final String apiPath;
+
+    public MicrosoftModel(
+            SeaTunnelRowType rowType,
+            SqlType outputType,
+            List<String> projectionColumns,
+            String prompt,
+            String model,
+            String apiKey,
+            String apiPath) {
+        super(rowType, outputType, projectionColumns, prompt);
+        this.model = model;
+        this.apiKey = apiKey;
+        this.apiPath =
+                CustomConfigPlaceholder.replacePlaceholders(
+                        apiPath, CustomConfigPlaceholder.REPLACE_PLACEHOLDER_MODEL, model, null);
+        this.client = HttpClients.createDefault();
+    }
+
+    @Override
+    protected List<String> chatWithModel(String prompt, String data) throws IOException {
+        HttpPost post = new HttpPost(apiPath);
+        post.setHeader("Authorization", "Bearer " + apiKey);
+        post.setHeader("Content-Type", "application/json");
+        ObjectNode objectNode = createJsonNodeFromData(prompt, data);
+        post.setEntity(new StringEntity(OBJECT_MAPPER.writeValueAsString(objectNode), "UTF-8"));
+        post.setConfig(
+                RequestConfig.custom().setConnectTimeout(20000).setSocketTimeout(20000).build());
+        CloseableHttpResponse response = client.execute(post);
+        String responseStr = EntityUtils.toString(response.getEntity());
+        if (response.getStatusLine().getStatusCode() != 200) {
+            throw new IOException("Failed to chat with model, response: " + responseStr);
+        }
+
+        JsonNode result = OBJECT_MAPPER.readTree(responseStr);
+        String resultData = result.get("choices").get(0).get("message").get("content").asText();
+        return OBJECT_MAPPER.readValue(
+                convertData(resultData), new TypeReference<List<String>>() {});
+    }
+
+    @VisibleForTesting
+    public ObjectNode createJsonNodeFromData(String prompt, String data) {
+        ObjectNode objectNode = OBJECT_MAPPER.createObjectNode();
+        ArrayNode messages = objectNode.putArray("messages");
+        messages.addObject().put("role", "system").put("content", prompt);
+        messages.addObject().put("role", "user").put("content", data);
+        return objectNode;
+    }
+
+    @Override
+    public void close() throws IOException {
+        if (client != null) {
+            client.close();
+        }
+    }
+}
diff --git a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/llm/LLMRequestJsonTest.java b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/llm/LLMRequestJsonTest.java
index 91666c4139..870af980fe 100644
--- a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/llm/LLMRequestJsonTest.java
+++ b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/llm/LLMRequestJsonTest.java
@@ -28,6 +28,7 @@ import org.apache.seatunnel.api.table.type.SqlType;
 import org.apache.seatunnel.format.json.RowToJsonConverters;
 import org.apache.seatunnel.transform.nlpmodel.llm.remote.custom.CustomModel;
 import org.apache.seatunnel.transform.nlpmodel.llm.remote.kimiai.KimiAIModel;
+import org.apache.seatunnel.transform.nlpmodel.llm.remote.microsoft.MicrosoftModel;
 import org.apache.seatunnel.transform.nlpmodel.llm.remote.openai.OpenAIModel;
 
 import org.junit.jupiter.api.Assertions;
@@ -36,6 +37,7 @@ import org.junit.jupiter.api.Test;
 import com.google.common.collect.Lists;
 
 import java.io.IOException;
+import java.lang.reflect.Field;
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
@@ -130,6 +132,38 @@ public class LLMRequestJsonTest {
         model.close();
     }
 
+    @Test
+    void testMicrosoftRequestJson() throws Exception {
+        SeaTunnelRowType rowType =
+                new SeaTunnelRowType(
+                        new String[] {"id", "name"},
+                        new SeaTunnelDataType[] {BasicType.INT_TYPE, BasicType.STRING_TYPE});
+        MicrosoftModel model =
+                new MicrosoftModel(
+                        rowType,
+                        SqlType.STRING,
+                        null,
+                        "Determine whether someone is Chinese or American by their name",
+                        "gpt-35-turbo",
+                        "sk-xxx",
+                        "https://api.moonshot.cn/openai/deployments/${model}/chat/completions?api-version=2024-02-01");
+        Field apiPathField = model.getClass().getDeclaredField("apiPath");
+        apiPathField.setAccessible(true);
+        String apiPath = (String) apiPathField.get(model);
+        Assertions.assertEquals(
+                "https://api.moonshot.cn/openai/deployments/gpt-35-turbo/chat/completions?api-version=2024-02-01",
+                apiPath);
+
+        ObjectNode node =
+                model.createJsonNodeFromData(
+                        "Determine whether someone is Chinese or American by their name",
+                        "{\"id\":1, \"name\":\"John\"}");
+        Assertions.assertEquals(
+                "{\"messages\":[{\"role\":\"system\",\"content\":\"Determine whether someone is Chinese or American by their name\"},{\"role\":\"user\",\"content\":\"{\\\"id\\\":1, \\\"name\\\":\\\"John\\\"}\"}]}",
+                OBJECT_MAPPER.writeValueAsString(node));
+        model.close();
+    }
+
     @Test
     void testCustomRequestJson() throws IOException {
         SeaTunnelRowType rowType =
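Putting the pieces together, for the first FakeSource row in `llm_microsoft_transform.conf` the transform would POST to the mock endpoint with a system/user message pair. A sketch of that request, assembled with a plain Jackson `ObjectMapper` (the exact row-to-JSON formatting and the standalone class here are assumptions for illustration, not code from the repository):

```java
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;

public class MicrosoftRequestSketch {
    public static void main(String[] args) throws Exception {
        ObjectMapper mapper = new ObjectMapper();

        // api_path from llm_microsoft_transform.conf with ${model} resolved to gpt-35-turbo.
        String endpoint =
                "http://mockserver:1080/openai/deployments/gpt-35-turbo/chat/completions?api-version=2024-02-01";

        // Same message layout MicrosoftModel.createJsonNodeFromData builds:
        // one system message with the prompt, one user message with the row serialized as JSON.
        ObjectNode body = mapper.createObjectNode();
        ArrayNode messages = body.putArray("messages");
        messages.addObject()
                .put("role", "system")
                .put("content", "Determine whether someone is Chinese or American by their name");
        messages.addObject().put("role", "user").put("content", "{\"id\":1, \"name\":\"Jia Fan\"}");

        System.out.println("POST " + endpoint);
        System.out.println(mapper.writeValueAsString(body));
    }
}
```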