fsk119 commented on code in PR #26652:
URL: https://github.com/apache/flink/pull/26652#discussion_r2139125807


##########
flink-models/pom.xml:
##########
@@ -0,0 +1,40 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied.  See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+<project xmlns="http://maven.apache.org/POM/4.0.0";
+                xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance";
+                xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 
http://maven.apache.org/maven-v4_0_0.xsd";>
+
+       <modelVersion>4.0.0</modelVersion>
+
+       <parent>
+               <groupId>org.apache.flink</groupId>
+               <artifactId>flink-parent</artifactId>
+               <version>2.1-SNAPSHOT</version>
+       </parent>
+
+       <artifactId>flink-models</artifactId>
+       <name>Flink : Models : </name>
+       <packaging>pom</packaging>
+
+       <modules>
+               <module>flink-model-openai</module>
+       </modules>
+

Review Comment:
   I think we can follow the same approach as the flink-connectors pom.xml:
   * add some default dependencies, e.g. slf4j
   * enable dependency convergence checks by default



##########
flink-models/flink-model-openai/pom.xml:
##########
@@ -0,0 +1,142 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied.  See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+<project xmlns="http://maven.apache.org/POM/4.0.0";
+                xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance";
+                xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 
http://maven.apache.org/maven-v4_0_0.xsd";>
+
+       <modelVersion>4.0.0</modelVersion>
+
+       <parent>
+               <groupId>org.apache.flink</groupId>
+               <artifactId>flink-models</artifactId>
+               <version>2.1-SNAPSHOT</version>
+       </parent>
+
+       <artifactId>flink-model-openai</artifactId>
+       <name>Flink : Models : OpenAI</name>
+
+       <properties>
+               <openai.version>1.6.1</openai.version>
+               <openai.kotlin.version>1.9.10</openai.kotlin.version>
+               <okhttp.version>4.12.0</okhttp.version>
+               <test.gson.version>2.11.0</test.gson.version>
+       </properties>
+
+       <dependencies>
+               <dependency>
+                       <groupId>com.openai</groupId>
+                       <artifactId>openai-java</artifactId>
+                       <version>${openai.version}</version>
+               </dependency>
+
+               <dependency>
+                       <groupId>com.squareup.okhttp3</groupId>
+                       <artifactId>okhttp</artifactId>
+                       <version>${okhttp.version}</version>
+               </dependency>
+
+               <dependency>
+                       <groupId>org.slf4j</groupId>
+                       <artifactId>slf4j-api</artifactId>
+                       <version>${slf4j.version}</version>
+               </dependency>
+               <dependency>
+                       <groupId>org.apache.flink</groupId>
+                       <artifactId>flink-core</artifactId>
+                       <version>${project.version}</version>
+               </dependency>
+               <dependency>
+                       <groupId>org.apache.flink</groupId>
+                       <artifactId>flink-table-api-java</artifactId>
+                       <version>${project.version}</version>
+               </dependency>
+               <dependency>
+                       <groupId>org.apache.flink</groupId>
+                       <artifactId>flink-table-common</artifactId>
+                       <version>${project.version}</version>
+               </dependency>
+
+               <dependency>
+                       <groupId>org.apache.flink</groupId>
+                       
<artifactId>flink-table-planner_${scala.binary.version}</artifactId>
+                       <version>${project.version}</version>
+                       <scope>test</scope>
+               </dependency>
+
+               <dependency>
+                       <groupId>com.squareup.okhttp3</groupId>
+                       <artifactId>mockwebserver</artifactId>
+                       <version>${okhttp.version}</version>
+                       <scope>test</scope>
+               </dependency>
+
+               <dependency>
+                       <groupId>com.google.code.gson</groupId>

Review Comment:
   Do we need this gson dependency?



##########
flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIUtils.java:
##########
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.model.openai;
+
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.util.Preconditions;
+
+import com.openai.client.OpenAIClientAsync;
+import com.openai.client.okhttp.OpenAIOkHttpClientAsync;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Map;
+import java.util.Objects;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicInteger;
+
+/** Utility class related to Open AI SDK. */
+public class OpenAIUtils {
+    private static final Logger LOG = 
LoggerFactory.getLogger(OpenAIUtils.class);
+
+    private static final Object LOCK = new Object();
+
+    private static final Map<ReferenceKey, ReferenceValue> cache = new 
ConcurrentHashMap<>();

Review Comment:
   nit: It seems all reads and modifications of the cache already happen under the lock,
   so a plain HashMap may be enough.



##########
flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIEmbeddingModelFunction.java:
##########
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.model.openai;
+
+import org.apache.flink.configuration.ConfigOption;
+import org.apache.flink.configuration.ConfigOptions;
+import org.apache.flink.configuration.ReadableConfig;
+import org.apache.flink.table.data.GenericArrayData;
+import org.apache.flink.table.data.GenericRowData;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.factories.ModelProviderFactory;
+import org.apache.flink.table.functions.AsyncPredictFunction;
+import org.apache.flink.table.types.logical.ArrayType;
+import org.apache.flink.table.types.logical.FloatType;
+
+import com.openai.models.embeddings.CreateEmbeddingResponse;
+import com.openai.models.embeddings.EmbeddingCreateParams;
+import com.openai.models.embeddings.EmbeddingCreateParams.EncodingFormat;
+
+import javax.annotation.Nullable;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.stream.Collectors;
+
+/** {@link AsyncPredictFunction} for OpenAI embedding task. */
+public class OpenAIEmbeddingModelFunction extends AbstractOpenAIModelFunction {

Review Comment:
   Please add a serialVersionUID.



##########
flink-models/flink-model-openai/pom.xml:
##########
@@ -0,0 +1,142 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied.  See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+<project xmlns="http://maven.apache.org/POM/4.0.0";
+                xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance";
+                xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 
http://maven.apache.org/maven-v4_0_0.xsd";>
+
+       <modelVersion>4.0.0</modelVersion>
+
+       <parent>
+               <groupId>org.apache.flink</groupId>
+               <artifactId>flink-models</artifactId>
+               <version>2.1-SNAPSHOT</version>
+       </parent>
+
+       <artifactId>flink-model-openai</artifactId>
+       <name>Flink : Models : OpenAI</name>
+
+       <properties>
+               <openai.version>1.6.1</openai.version>
+               <openai.kotlin.version>1.9.10</openai.kotlin.version>
+               <okhttp.version>4.12.0</okhttp.version>

Review Comment:
   Can we use the same okhttp version that is defined in the Flink root pom.xml?
   https://github.com/apache/flink/blob/master/pom.xml#L165



##########
flink-models/flink-model-openai/src/test/java/org/apache/flink/model/openai/OpenAIChatModelTest.java:
##########
@@ -0,0 +1,311 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.model.openai;
+
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.TableEnvironment;
+import org.apache.flink.table.api.TableResult;
+import org.apache.flink.table.api.internal.TableEnvironmentImpl;
+import org.apache.flink.table.catalog.CatalogManager;
+import org.apache.flink.table.catalog.CatalogModel;
+import org.apache.flink.table.catalog.ObjectIdentifier;
+import org.apache.flink.types.Row;
+
+import com.google.gson.JsonObject;
+import com.google.gson.JsonParser;
+import okhttp3.mockwebserver.Dispatcher;
+import okhttp3.mockwebserver.MockResponse;
+import okhttp3.mockwebserver.MockWebServer;
+import okhttp3.mockwebserver.RecordedRequest;
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/** Test for {@link OpenAIChatModelFunction}. */
+public class OpenAIChatModelTest {
+    private static final String MODEL_NAME = "m";
+
+    private static final Schema INPUT_SCHEMA =
+            Schema.newBuilder().column("input", DataTypes.STRING()).build();
+    private static final Schema OUTPUT_SCHEMA =
+            Schema.newBuilder().column("content", DataTypes.STRING()).build();
+
+    private static MockWebServer server;
+
+    private Map<String, String> modelOptions;
+
+    private TableEnvironment tEnv;
+
+    @BeforeAll
+    public static void beforeAll() throws IOException {
+        server = new MockWebServer();
+        server.setDispatcher(new TestDispatcher());
+        server.start();
+    }
+
+    @AfterAll
+    public static void afterAll() throws IOException {
+        if (server != null) {
+            server.close();
+        }
+    }
+
+    @BeforeEach
+    public void setup() {
+        tEnv = TableEnvironment.create(new Configuration());
+        tEnv.executeSql(
+                "CREATE TABLE MyTable(input STRING, invalid_input DOUBLE) WITH 
( 'connector' = 'datagen', 'number-of-rows' = '10')");
+
+        modelOptions = new HashMap<>();
+        modelOptions.put("provider", "openai");
+        modelOptions.put("endpoint", 
server.url("/chat/completions").toString());
+        modelOptions.put("model", "qwen-turbo");
+        modelOptions.put("apiKey", "foobar");
+    }
+
+    @AfterEach
+    public void afterEach() {
+        assertThat(OpenAIUtils.getCache()).isEmpty();
+    }
+
+    @Test
+    public void testChat() {
+        CatalogManager catalogManager = ((TableEnvironmentImpl) 
tEnv).getCatalogManager();
+        catalogManager.createModel(
+                CatalogModel.of(INPUT_SCHEMA, OUTPUT_SCHEMA, modelOptions, 
"This is a new model."),
+                ObjectIdentifier.of(
+                        catalogManager.getCurrentCatalog(),
+                        catalogManager.getCurrentDatabase(),
+                        MODEL_NAME),
+                false);
+
+        TableResult tableResult =
+                tEnv.executeSql(
+                        String.format(
+                                "SELECT input, content FROM ML_PREDICT(TABLE 
MyTable, MODEL %s, DESCRIPTOR(`input`))",
+                                MODEL_NAME));
+        List<Row> result = IteratorUtils.toList(tableResult.collect());
+        assertThat(result).hasSize(10);
+        for (Row row : result) {
+            assertThat(row.getField(0)).isInstanceOf(String.class);
+            assertThat(row.getField(1)).isInstanceOf(String.class);
+            assertThat((String) row.getFieldAs(1)).isNotEmpty();

Review Comment:
     assertThat((String) row.getFieldAs(1))
                       .isEqualTo(
                               "This is a mocked response continuation 
continuation continuation continuation continuation continuation continuation 
continuation continuation continuation");
     



##########
flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/AbstractOpenAIModelFunction.java:
##########
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.model.openai;
+
+import org.apache.flink.configuration.ConfigOption;
+import org.apache.flink.configuration.ConfigOptions;
+import org.apache.flink.configuration.ReadableConfig;
+import org.apache.flink.configuration.description.Description;
+import org.apache.flink.table.api.config.ExecutionConfigOptions;
+import org.apache.flink.table.catalog.Column;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.factories.ModelProviderFactory;
+import org.apache.flink.table.functions.AsyncPredictFunction;
+import org.apache.flink.table.functions.FunctionContext;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.VarCharType;
+
+import com.openai.client.OpenAIClientAsync;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.List;
+import java.util.stream.Collectors;
+
+import static org.apache.flink.configuration.description.TextElement.code;
+
+/** Abstract parent class for {@link AsyncPredictFunction}s for OpenAI API. */
+public abstract class AbstractOpenAIModelFunction extends AsyncPredictFunction 
{
+    private static final Logger LOG = 
LoggerFactory.getLogger(AbstractOpenAIModelFunction.class);
+
+    public static final ConfigOption<String> ENDPOINT =
+            ConfigOptions.key("endpoint")
+                    .stringType()
+                    .noDefaultValue()
+                    .withDescription(
+                            Description.builder()
+                                    .text(
+                                            "Full URL of the OpenAI API 
endpoint, e.g., %s or %s",
+                                            
code("https://api.openai.com/v1/chat/completions";),
+                                            
code("https://api.openai.com/v1/embeddings";))
+                                    .build());
+
+    public static final ConfigOption<String> API_KEY =
+            ConfigOptions.key("apiKey")
+                    .stringType()
+                    .noDefaultValue()
+                    .withDescription("OpenAI API key for authentication.");
+
+    public static final ConfigOption<String> MODEL =
+            ConfigOptions.key("model")
+                    .stringType()
+                    .noDefaultValue()
+                    .withDescription(
+                            Description.builder()
+                                    .text(
+                                            "Model name, e.g., %s, %s.",
+                                            code("gpt-3.5-turbo"), 
code("text-embedding-ada-002"))
+                                    .build());
+
+    protected transient OpenAIClientAsync client;
+
+    private final int numRetry;
+    private final String baseUrl;
+    private final String apiKey;
+
+    public AbstractOpenAIModelFunction(
+            ModelProviderFactory.Context factoryContext, ReadableConfig 
config) {
+        String endpoint = config.get(ENDPOINT);
+        this.baseUrl = endpoint.replaceAll(String.format("/%s/*$", 
getEndpointSuffix()), "");
+        this.apiKey = config.get(API_KEY);
+        this.numRetry =

Review Comment:
   Please add a comment describing this behaviour, e.g.:
   ```
   The model service enforces rate limits, so retries are needed in most scenarios.
   The async operator can have up to
   config.get(ExecutionConfigOptions.TABLE_EXEC_ASYNC_LOOKUP_BUFFER_CAPACITY)
   requests in flight at the same time, and these concurrent requests tend to hit the
   rate limit together. Therefore, the maximum retry count is proportional to this
   configured concurrency (the async buffer capacity). This keeps the function resilient
   to rate-limit errors without sacrificing throughput.
   ```



##########
flink-models/flink-model-openai/src/test/java/org/apache/flink/model/openai/OpenAIChatModelTest.java:
##########
@@ -0,0 +1,311 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.model.openai;
+
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.TableEnvironment;
+import org.apache.flink.table.api.TableResult;
+import org.apache.flink.table.api.internal.TableEnvironmentImpl;
+import org.apache.flink.table.catalog.CatalogManager;
+import org.apache.flink.table.catalog.CatalogModel;
+import org.apache.flink.table.catalog.ObjectIdentifier;
+import org.apache.flink.types.Row;
+
+import com.google.gson.JsonObject;
+import com.google.gson.JsonParser;
+import okhttp3.mockwebserver.Dispatcher;
+import okhttp3.mockwebserver.MockResponse;
+import okhttp3.mockwebserver.MockWebServer;
+import okhttp3.mockwebserver.RecordedRequest;
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/** Test for {@link OpenAIChatModelFunction}. */
+public class OpenAIChatModelTest {
+    private static final String MODEL_NAME = "m";
+
+    private static final Schema INPUT_SCHEMA =
+            Schema.newBuilder().column("input", DataTypes.STRING()).build();
+    private static final Schema OUTPUT_SCHEMA =
+            Schema.newBuilder().column("content", DataTypes.STRING()).build();
+
+    private static MockWebServer server;
+
+    private Map<String, String> modelOptions;
+
+    private TableEnvironment tEnv;
+
+    @BeforeAll
+    public static void beforeAll() throws IOException {
+        server = new MockWebServer();
+        server.setDispatcher(new TestDispatcher());
+        server.start();
+    }
+
+    @AfterAll
+    public static void afterAll() throws IOException {
+        if (server != null) {
+            server.close();
+        }
+    }
+
+    @BeforeEach
+    public void setup() {
+        tEnv = TableEnvironment.create(new Configuration());
+        tEnv.executeSql(
+                "CREATE TABLE MyTable(input STRING, invalid_input DOUBLE) WITH 
( 'connector' = 'datagen', 'number-of-rows' = '10')");
+
+        modelOptions = new HashMap<>();
+        modelOptions.put("provider", "openai");
+        modelOptions.put("endpoint", 
server.url("/chat/completions").toString());
+        modelOptions.put("model", "qwen-turbo");
+        modelOptions.put("apiKey", "foobar");
+    }
+
+    @AfterEach
+    public void afterEach() {
+        assertThat(OpenAIUtils.getCache()).isEmpty();
+    }
+
+    @Test
+    public void testChat() {
+        CatalogManager catalogManager = ((TableEnvironmentImpl) 
tEnv).getCatalogManager();
+        catalogManager.createModel(
+                CatalogModel.of(INPUT_SCHEMA, OUTPUT_SCHEMA, modelOptions, 
"This is a new model."),
+                ObjectIdentifier.of(
+                        catalogManager.getCurrentCatalog(),
+                        catalogManager.getCurrentDatabase(),
+                        MODEL_NAME),
+                false);
+
+        TableResult tableResult =
+                tEnv.executeSql(
+                        String.format(
+                                "SELECT input, content FROM ML_PREDICT(TABLE 
MyTable, MODEL %s, DESCRIPTOR(`input`))",
+                                MODEL_NAME));
+        List<Row> result = IteratorUtils.toList(tableResult.collect());
+        assertThat(result).hasSize(10);
+        for (Row row : result) {
+            assertThat(row.getField(0)).isInstanceOf(String.class);
+            assertThat(row.getField(1)).isInstanceOf(String.class);
+            assertThat((String) row.getFieldAs(1)).isNotEmpty();
+        }
+    }
+
+    @Test
+    public void testMaxToken() {
+        int maxTokens = 20;
+        CatalogManager catalogManager = ((TableEnvironmentImpl) 
tEnv).getCatalogManager();
+        Map<String, String> modelOptions = new HashMap<>(this.modelOptions);
+        modelOptions.put("maxTokens", Integer.toString(maxTokens));
+        catalogManager.createModel(
+                CatalogModel.of(INPUT_SCHEMA, OUTPUT_SCHEMA, modelOptions, 
"This is a new model."),
+                ObjectIdentifier.of(
+                        catalogManager.getCurrentCatalog(),
+                        catalogManager.getCurrentDatabase(),
+                        MODEL_NAME),
+                false);
+
+        TableResult tableResult =
+                tEnv.executeSql(
+                        String.format(
+                                "SELECT input, content FROM ML_PREDICT(TABLE 
MyTable, MODEL %s, DESCRIPTOR(`input`))",
+                                MODEL_NAME));
+        List<Row> result = IteratorUtils.toList(tableResult.collect());
+        assertThat(result).hasSize(10);
+        for (Row row : result) {
+            assertThat(row.getField(0)).isInstanceOf(String.class);
+            assertThat(row.getField(1)).isInstanceOf(String.class);
+            assertThat((String) row.getFieldAs(1)).isNotEmpty();
+            assertThat(((String) row.getFieldAs(1)).split(" 
")).hasSizeLessThan(maxTokens);
+        }
+    }
+
+    @Test
+    public void testStop() {
+        String stop = "a,the";
+        CatalogManager catalogManager = ((TableEnvironmentImpl) 
tEnv).getCatalogManager();
+        Map<String, String> modelOptions = new HashMap<>(this.modelOptions);
+        modelOptions.put("stop", stop);
+        catalogManager.createModel(
+                CatalogModel.of(INPUT_SCHEMA, OUTPUT_SCHEMA, modelOptions, 
"This is a new model."),
+                ObjectIdentifier.of(
+                        catalogManager.getCurrentCatalog(),
+                        catalogManager.getCurrentDatabase(),
+                        MODEL_NAME),
+                false);
+
+        TableResult tableResult =
+                tEnv.executeSql(
+                        String.format(
+                                "SELECT input, content FROM ML_PREDICT(TABLE 
MyTable, MODEL %s, DESCRIPTOR(`input`))",
+                                MODEL_NAME));
+        List<Row> result = IteratorUtils.toList(tableResult.collect());
+        assertThat(result).hasSize(10);
+        for (Row row : result) {
+            assertThat(row.getField(0)).isInstanceOf(String.class);
+            assertThat(row.getField(1)).isInstanceOf(String.class);
+            assertThat((String) row.getFieldAs(1)).isNotEmpty();
+            assertThat(((String) row.getFieldAs(1)).split(" "))
+                    .doesNotContain("a")
+                    .doesNotContain("the");
+        }
+    }
+
+    @Test
+    public void testInvalidInputSchema() {
+        CatalogManager catalogManager = ((TableEnvironmentImpl) 
tEnv).getCatalogManager();
+        ObjectIdentifier modelIdentifier =
+                ObjectIdentifier.of(
+                        catalogManager.getCurrentCatalog(),
+                        catalogManager.getCurrentDatabase(),
+                        MODEL_NAME);
+
+        Schema inputSchemaWithInvalidColumnType =
+                Schema.newBuilder().column("input", 
DataTypes.DOUBLE()).build();
+
+        catalogManager.createModel(
+                CatalogModel.of(
+                        inputSchemaWithInvalidColumnType,
+                        OUTPUT_SCHEMA,
+                        modelOptions,
+                        "This is a new model."),
+                modelIdentifier,
+                false);
+        assertThatThrownBy(
+                        () ->
+                                tEnv.executeSql(
+                                        String.format(
+                                                "SELECT * FROM 
TABLE(ML_PREDICT(TABLE MyTable, MODEL %s, DESCRIPTOR(`invalid_input`)))",
+                                                MODEL_NAME)))
+                .rootCause()
+                .isInstanceOf(IllegalArgumentException.class)
+                .hasMessageContainingAll("input", "DOUBLE", "STRING");
+    }
+
+    @Test
+    public void testInvalidOutputSchema() {
+        CatalogManager catalogManager = ((TableEnvironmentImpl) 
tEnv).getCatalogManager();
+        ObjectIdentifier modelIdentifier =
+                ObjectIdentifier.of(
+                        catalogManager.getCurrentCatalog(),
+                        catalogManager.getCurrentDatabase(),
+                        MODEL_NAME);
+
+        Schema outputSchemaWithInvalidColumnType =
+                Schema.newBuilder().column("output", 
DataTypes.DOUBLE()).build();
+
+        catalogManager.createModel(
+                CatalogModel.of(
+                        INPUT_SCHEMA,
+                        outputSchemaWithInvalidColumnType,
+                        modelOptions,
+                        "This is a new model."),
+                modelIdentifier,
+                false);
+        assertThatThrownBy(
+                        () ->
+                                tEnv.executeSql(
+                                        String.format(
+                                                "SELECT * FROM 
TABLE(ML_PREDICT(TABLE MyTable, MODEL %s, DESCRIPTOR(`input`)))",
+                                                MODEL_NAME)))
+                .rootCause()
+                .isInstanceOf(IllegalArgumentException.class)
+                .hasMessageContainingAll("output", "DOUBLE", "STRING");
+    }
+
+    private static class TestDispatcher extends Dispatcher {
+        @Override
+        public MockResponse dispatch(RecordedRequest request) {

Review Comment:
   I think we could do something like this:
   
   ```
           public MockResponse dispatch(RecordedRequest request) {
               String path = request.getRequestUrl().encodedPath();
   
               String body = request.getBody().readUtf8();
   
               if (!path.endsWith("/chat/completions")) {
                   return new MockResponse().setResponseCode(404);
               }
   
               try {
                   JsonNode root = OBJECT_MAPPER.readTree(body);
                   int maxTokens = root.has("max_tokens") ? 
root.get("max_tokens").asInt() : 16;
                   List<String> stop = new ArrayList<>();
                   if (root.has("stop")) {
                       root.get("stop").forEach(node -> 
stop.add(node.asText()));
                   }
   
                   StringBuilder contentBuilder = new StringBuilder("This is a 
mocked response");
                   contentBuilder.append(" continuation".repeat(Math.max(0, 
maxTokens - 6)));
                   for (String stopWord : stop) {
                       if (contentBuilder.toString().contains(stopWord)) {
                           int stopIndex = contentBuilder.indexOf(stopWord);
                           if (stopIndex > 0) {
                               contentBuilder.delete(stopIndex, 
contentBuilder.length());
                           }
                       }
                   }
   
                   String responseBody =
                           "{"
                                   + "  \"id\": \"chatcmpl-1234567890ABCD\","
                                   + "  \"object\": \"chat.completion\","
                                   + "  \"created\": 1717029203,"
                                   + "  \"model\": \"gpt-3.5-turbo-0125\","
                                   + "  \"choices\": [{"
                                   + "    \"index\": 0,"
                                   + "    \"message\": {"
                                   + "      \"role\": \"assistant\","
                                   + "      \"content\": \""
                                   + contentBuilder
                                   + "\""
                                   + "    },"
                                   + "    \"finish_reason\": \"stop\""
                                   + "  }],"
                                   + "  \"usage\": {"
                                   + "    \"prompt_tokens\": 9,"
                                   + "    \"completion_tokens\": "
                                   + Math.min(maxTokens, 100)
                                   + ","
                                   + "    \"total_tokens\": "
                                   + (9 + Math.min(maxTokens, 100))
                                   + "  }"
                                   + "}";
   
                   return new MockResponse()
                           .setHeader("Content-Type", "application/json")
                           .setBody(responseBody);
               } catch (Exception e) {
                   throw new RuntimeException(e);
               }
           }
   ```



##########
flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIChatModelFunction.java:
##########
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.model.openai;
+
+import org.apache.flink.configuration.ConfigOption;
+import org.apache.flink.configuration.ConfigOptions;
+import org.apache.flink.configuration.ReadableConfig;
+import org.apache.flink.table.data.GenericRowData;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.data.binary.BinaryStringData;
+import org.apache.flink.table.factories.ModelProviderFactory;
+import org.apache.flink.table.functions.AsyncPredictFunction;
+import org.apache.flink.table.types.logical.VarCharType;
+
+import com.openai.models.chat.completions.ChatCompletion;
+import com.openai.models.chat.completions.ChatCompletionCreateParams;
+
+import javax.annotation.Nullable;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.stream.Collectors;
+
+/** {@link AsyncPredictFunction} for OpenAI chat completion task. */
+public class OpenAIChatModelFunction extends AbstractOpenAIModelFunction {

Review Comment:
   Please add a `serialVersionUID` field to this class.



##########
docs/content.zh/docs/dev/table/models/_index.md:
##########
@@ -0,0 +1,23 @@
+---
+title: 模型

Review Comment:
   It would be better to move the doc here.
   
   
![image](https://github.com/user-attachments/assets/bab91987-f4cd-4351-96a1-15855802d199)



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to