This is an automated email from the ASF dual-hosted git repository.

stigahuang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/impala.git

commit 1c6ff5f98d010ad81601aaa60662dfb7bec3b5d9
Author: Yida Wu <[email protected]>
AuthorDate: Mon Mar 3 11:08:53 2025 -0800

    IMPALA-13812: Fail query for certain errors related to AI functions
    
    The ai_generate_text() and ai_generate_text_default() functions
    returned error messages as their (string) result, which could be
    misleading in some cases. This patch fixes the issue by setting
    the error in the function context as a UDF error, causing the
    query to fail on configuration-related errors (e.g. malformed
    JSON options or forbidden 'messages'/'n' overrides) or on HTTP
    errors when accessing the AI endpoint. An empty prompt still
    returns an error string instead of failing the query.
    
    Tests:
    Ran core tests.
    Added custom cluster test TestAIGenerateText for failure cases
    of ai_generate_text_default().
    Added test TestExprs.test_ai_generate_text_exprs for failure
    cases of ai_generate_text() and ai_generate_text_default().
    
    Change-Id: I639e48e64d62f7990cf9a3c35a59a0ee3a2c64e0
    Reviewed-on: http://gerrit.cloudera.org:8080/22588
    Reviewed-by: Yida Wu <[email protected]>
    Tested-by: Impala Public Jenkins <[email protected]>
---
 be/src/exprs/ai-functions-ir.cc                    |  2 +
 be/src/exprs/ai-functions.cc                       | 63 +++++++++++------
 be/src/exprs/ai-functions.h                        |  1 +
 be/src/exprs/expr-test.cc                          | 66 ++++++++++++++----
 .../org/apache/impala/service/JniFrontend.java     |  4 +-
 .../queries/QueryTest/ai_generate_text_exprs.test  | 79 ++++++++++++++++++++++
 .../functional-query/queries/QueryTest/exprs.test  | 20 ------
 tests/custom_cluster/test_ai_generate_text.py      | 67 ++++++++++++++++++
 tests/query_test/test_exprs.py                     |  6 ++
 9 files changed, 254 insertions(+), 54 deletions(-)

diff --git a/be/src/exprs/ai-functions-ir.cc b/be/src/exprs/ai-functions-ir.cc
index 2b36e268b..6f4772caf 100644
--- a/be/src/exprs/ai-functions-ir.cc
+++ b/be/src/exprs/ai-functions-ir.cc
@@ -51,6 +51,8 @@ const string 
AiFunctions::AI_GENERATE_TXT_MSG_OVERRIDE_FORBIDDEN_ERROR =
     "Invalid override, 'messages' cannot be overriden";
 const string AiFunctions::AI_GENERATE_TXT_N_OVERRIDE_FORBIDDEN_ERROR =
     "Invalid override, 'n' must be of integer type and have value 1";
+const string AiFunctions::AI_GENERATE_TXT_COMMON_ERROR_PREFIX =
+    "AI Generate Text Error: ";
 string AiFunctions::ai_api_key_;
 const char* AiFunctions::OPEN_AI_REQUEST_FIELD_CONTENT_TYPE_HEADER =
     "Content-Type: application/json";
diff --git a/be/src/exprs/ai-functions.cc b/be/src/exprs/ai-functions.cc
index 2e7dd6617..1ef2b079d 100644
--- a/be/src/exprs/ai-functions.cc
+++ b/be/src/exprs/ai-functions.cc
@@ -58,12 +58,20 @@ TAG_FLAG(ai_api_key_jceks_secret, sensitive);
 
 namespace impala {
 
-#define RETURN_STRINGVAL_IF_ERROR(ctx, stmt)               \
-  do {                                                     \
-    const ::impala::Status& _status = (stmt);              \
-    if (UNLIKELY(!_status.ok())) {                         \
-      return copyErrorMessage(ctx, _status.msg().msg());   \
-    }                                                      \
+// Set an error message in the context, causing the query to fail.
+#define SET_ERROR(ctx, status_str, prefix)          \
+  do {                                              \
+    (ctx)->SetError((prefix + status_str).c_str()); \
+  } while (false)
+
+// Check the status and return an error if it fails.
+#define RETURN_STRINGVAL_IF_ERROR(ctx, stmt)                                   
 \
+  do {                                                                         
 \
+    const ::impala::Status& _status = (stmt);                                  
 \
+    if (UNLIKELY(!_status.ok())) {                                             
 \
+      SET_ERROR(ctx, _status.msg().msg(), 
AI_GENERATE_TXT_COMMON_ERROR_PREFIX); \
+      return StringVal::null();                                                
 \
+    }                                                                          
 \
   } while (false)
 
 // Impala Ai Functions Options Constants.
@@ -181,8 +189,11 @@ StringVal 
AiFunctions::AiGenerateTextInternal(FunctionContext* ctx,
     try {
       ParseImpalaOptions(impala_options, impala_options_document, ai_options);
     } catch (const std::runtime_error& e) {
-      LOG(WARNING) << AI_GENERATE_TXT_JSON_PARSE_ERROR << ": " << e.what();
-      return StringVal(AI_GENERATE_TXT_JSON_PARSE_ERROR.c_str());
+      std::stringstream ss;
+      ss << AI_GENERATE_TXT_JSON_PARSE_ERROR << ": " << e.what();
+      LOG(WARNING) << ss.str();
+      const Status err_status(ss.str());
+      RETURN_STRINGVAL_IF_ERROR(ctx, err_status);
     }
   }
 
@@ -238,6 +249,9 @@ StringVal 
AiFunctions::AiGenerateTextInternal(FunctionContext* ctx,
     Value message(rapidjson::kObjectType);
     message.AddMember("role", "user", payload_allocator);
     if (prompt.ptr == nullptr || prompt.len == 0) {
+      // Return a string with the invalid prompt error message instead of 
failing
+      // the query, as the issue may be with the row rather than the 
configuration
+      // or query. This behavior might be reconsidered later.
       return StringVal(AI_GENERATE_TXT_INVALID_PROMPT_ERROR.c_str());
     }
     message.AddMember("content",
@@ -253,26 +267,32 @@ StringVal 
AiFunctions::AiGenerateTextInternal(FunctionContext* ctx,
     if (!fastpath && platform_params.ptr != nullptr && platform_params.len != 
0) {
       overrides.Parse(reinterpret_cast<char*>(platform_params.ptr), 
platform_params.len);
       if (overrides.HasParseError()) {
-        LOG(WARNING) << AI_GENERATE_TXT_JSON_PARSE_ERROR << ": error code "
-                     << overrides.GetParseError() << ", offset input "
-                     << overrides.GetErrorOffset();
-        return StringVal(AI_GENERATE_TXT_JSON_PARSE_ERROR.c_str());
+        std::stringstream ss;
+        ss << AI_GENERATE_TXT_JSON_PARSE_ERROR << ": error code "
+           << overrides.GetParseError() << ", offset input "
+           << overrides.GetErrorOffset();
+        LOG(WARNING) << ss.str();
+        const Status err_status(ss.str());
+        RETURN_STRINGVAL_IF_ERROR(ctx, err_status);
       }
       for (auto& m : overrides.GetObject()) {
         if (payload.HasMember(m.name.GetString())) {
           if (m.name == "messages") {
-            LOG(WARNING)
-                << AI_GENERATE_TXT_JSON_PARSE_ERROR
-                << ": 'messages' is constructed from 'prompt', cannot be 
overridden";
-            return 
StringVal(AI_GENERATE_TXT_MSG_OVERRIDE_FORBIDDEN_ERROR.c_str());
+            const string error_msg = 
AI_GENERATE_TXT_MSG_OVERRIDE_FORBIDDEN_ERROR
+                + ": 'messages' is constructed from 'prompt', cannot be 
overridden";
+            LOG(WARNING) << error_msg;
+            const Status err_status(error_msg);
+            RETURN_STRINGVAL_IF_ERROR(ctx, err_status);
           } else {
             payload[m.name.GetString()] = m.value;
           }
         } else {
           if ((m.name == "n") && !(m.value.IsInt() && m.value.GetInt() == 1)) {
-            LOG(WARNING) << AI_GENERATE_TXT_JSON_PARSE_ERROR
-                         << ": 'n' must be of integer type and have value 1";
-            return 
StringVal(AI_GENERATE_TXT_N_OVERRIDE_FORBIDDEN_ERROR.c_str());
+            const string error_msg = AI_GENERATE_TXT_N_OVERRIDE_FORBIDDEN_ERROR
+                + ": 'n' must be of integer type and have value 1";
+            LOG(WARNING) << error_msg;
+            const Status err_status(error_msg);
+            RETURN_STRINGVAL_IF_ERROR(ctx, err_status);
           }
           payload.AddMember(m.name, m.value, payload_allocator);
         }
@@ -312,7 +332,10 @@ StringVal 
AiFunctions::AiGenerateTextInternal(FunctionContext* ctx,
     status = curl.PostToURL(endpoint_str, payload_str, &resp, headers);
   }
   VLOG(2) << "AI Generate Text: \noriginal response: " << resp.ToString();
-  if (UNLIKELY(!status.ok())) return copyErrorMessage(ctx, status.ToString());
+  if (UNLIKELY(!status.ok())) {
+    SET_ERROR(ctx, status.ToString(), AI_GENERATE_TXT_COMMON_ERROR_PREFIX);
+    return StringVal::null();
+  }
   // Parse the JSON response string
   std::string response = AiGenerateTextParseOpenAiResponse(
       std::string_view(reinterpret_cast<char*>(resp.data()), resp.size()));
diff --git a/be/src/exprs/ai-functions.h b/be/src/exprs/ai-functions.h
index d72c3b54c..85a6e1b48 100644
--- a/be/src/exprs/ai-functions.h
+++ b/be/src/exprs/ai-functions.h
@@ -36,6 +36,7 @@ class AiFunctions {
   static const string AI_GENERATE_TXT_INVALID_PROMPT_ERROR;
   static const string AI_GENERATE_TXT_MSG_OVERRIDE_FORBIDDEN_ERROR;
   static const string AI_GENERATE_TXT_N_OVERRIDE_FORBIDDEN_ERROR;
+  static const string AI_GENERATE_TXT_COMMON_ERROR_PREFIX;
   static const char* OPEN_AI_REQUEST_FIELD_CONTENT_TYPE_HEADER;
   static const char* OPEN_AI_REQUEST_AUTH_HEADER;
   static const char* AZURE_OPEN_AI_REQUEST_AUTH_HEADER;
diff --git a/be/src/exprs/expr-test.cc b/be/src/exprs/expr-test.cc
index 0662354c3..e0a117a96 100644
--- a/be/src/exprs/expr-test.cc
+++ b/be/src/exprs/expr-test.cc
@@ -1259,6 +1259,16 @@ class ExprTest : public 
testing::TestWithParam<std::tuple<bool, bool>> {
         UdfTestHarness::CreateTestContext(return_type, arg_types, state, 
pool));
   }
 
+  // Helper function to close context then create a new one.
+  void RecreateUdfTestContext(const FunctionContext::TypeDesc& return_type,
+      const std::vector<FunctionContext::TypeDesc>& arg_types, RuntimeState* 
state,
+      MemPool* pool, FunctionContext** old_ctx) {
+    ASSERT_TRUE(old_ctx != nullptr && *old_ctx != nullptr);
+    UdfTestHarness::CloseContext(*old_ctx);
+    *old_ctx = CreateUdfTestContext(return_type, arg_types, state, pool);
+    ASSERT_TRUE(*old_ctx != nullptr);
+  }
+
   void TestBytes();
 };
 
@@ -11388,14 +11398,26 @@ TEST_P(ExprTest, AiFunctionsTest) {
   EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
       AiFunctions::AI_GENERATE_TXT_INVALID_PROMPT_ERROR);
 
+  size_t json_parse_error_len = 
AiFunctions::AI_GENERATE_TXT_JSON_PARSE_ERROR.size()
+      + AiFunctions::AI_GENERATE_TXT_COMMON_ERROR_PREFIX.size();
+  size_t override_forbidden_error_len =
+      AiFunctions::AI_GENERATE_TXT_MSG_OVERRIDE_FORBIDDEN_ERROR.size()
+      + AiFunctions::AI_GENERATE_TXT_COMMON_ERROR_PREFIX.size();
+  size_t n_override_forbidden_error_len =
+      AiFunctions::AI_GENERATE_TXT_N_OVERRIDE_FORBIDDEN_ERROR.size()
+      + AiFunctions::AI_GENERATE_TXT_COMMON_ERROR_PREFIX.size();
   // Test override/additional params
   // invalid json results in error.
   StringVal invalid_json_params("{\"temperature\": 0.49, \"stop\": 
[\"*\",::,]}");
   result = AiFunctions::AiGenerateTextInternal<false, 
AiFunctions::AI_PLATFORM::OPEN_AI>(
       ctx, openai_endpoint, prompt, model, jceks_secret, invalid_json_params,
       impala_options, dry_run);
-  EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
-      AiFunctions::AI_GENERATE_TXT_JSON_PARSE_ERROR);
+  EXPECT_TRUE(ctx->has_error());
+  EXPECT_EQ(string(ctx->error_msg(), json_parse_error_len),
+      AiFunctions::AI_GENERATE_TXT_COMMON_ERROR_PREFIX
+          + AiFunctions::AI_GENERATE_TXT_JSON_PARSE_ERROR);
+  EXPECT_EQ(result.ptr, nullptr);
+  RecreateUdfTestContext(str_desc, v, nullptr, &pool, &ctx);
   // valid json results in overriding existing params ('model'), and adding 
new parms
   // like 'temperature' and 'stop'.
   StringVal valid_json_params(
@@ -11415,22 +11437,34 @@ TEST_P(ExprTest, AiFunctionsTest) {
   result = AiFunctions::AiGenerateTextInternal<false, 
AiFunctions::AI_PLATFORM::OPEN_AI>(
       ctx, openai_endpoint, prompt, model, jceks_secret, 
forbidden_msg_override,
       impala_options, dry_run);
-  EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
-      AiFunctions::AI_GENERATE_TXT_MSG_OVERRIDE_FORBIDDEN_ERROR);
+  EXPECT_TRUE(ctx->has_error());
+  EXPECT_EQ(string(ctx->error_msg(), override_forbidden_error_len),
+      AiFunctions::AI_GENERATE_TXT_COMMON_ERROR_PREFIX
+          + AiFunctions::AI_GENERATE_TXT_MSG_OVERRIDE_FORBIDDEN_ERROR);
+  EXPECT_EQ(result.ptr, nullptr);
+  RecreateUdfTestContext(str_desc, v, nullptr, &pool, &ctx);
   // 'n != 1' cannot be overriden as additional params
   StringVal forbidden_n_value("{\"n\": 2}");
   result = AiFunctions::AiGenerateTextInternal<false, 
AiFunctions::AI_PLATFORM::OPEN_AI>(
       ctx, openai_endpoint, prompt, model, jceks_secret, forbidden_n_value,
       impala_options, dry_run);
-  EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
-      AiFunctions::AI_GENERATE_TXT_N_OVERRIDE_FORBIDDEN_ERROR);
+  EXPECT_TRUE(ctx->has_error());
+  EXPECT_EQ(string(ctx->error_msg(), n_override_forbidden_error_len),
+      AiFunctions::AI_GENERATE_TXT_COMMON_ERROR_PREFIX
+          + AiFunctions::AI_GENERATE_TXT_N_OVERRIDE_FORBIDDEN_ERROR);
+  EXPECT_EQ(result.ptr, nullptr);
+  RecreateUdfTestContext(str_desc, v, nullptr, &pool, &ctx);
   // non integer value of 'n' cannot be overriden as additional params
   StringVal forbidden_n_type("{\"n\": \"1\"}");
   result = AiFunctions::AiGenerateTextInternal<false, 
AiFunctions::AI_PLATFORM::OPEN_AI>(
       ctx, openai_endpoint, prompt, model, jceks_secret, forbidden_n_type, 
impala_options,
       dry_run);
-  EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
-      AiFunctions::AI_GENERATE_TXT_N_OVERRIDE_FORBIDDEN_ERROR);
+  EXPECT_TRUE(ctx->has_error());
+  EXPECT_EQ(string(ctx->error_msg(), n_override_forbidden_error_len),
+      AiFunctions::AI_GENERATE_TXT_COMMON_ERROR_PREFIX
+          + AiFunctions::AI_GENERATE_TXT_N_OVERRIDE_FORBIDDEN_ERROR);
+  EXPECT_EQ(result.ptr, nullptr);
+  RecreateUdfTestContext(str_desc, v, nullptr, &pool, &ctx);
   // accept 'n=1' override as additional params
   StringVal allowed_n_override("{\"n\": 1}");
   result = AiFunctions::AiGenerateTextInternal<false, 
AiFunctions::AI_PLATFORM::OPEN_AI>(
@@ -11493,8 +11527,12 @@ TEST_P(ExprTest, AiFunctionsTest) {
   result = AiFunctions::AiGenerateTextInternal<false, 
AiFunctions::AI_PLATFORM::OPEN_AI>(
       ctx, openai_endpoint, prompt, model, jceks_secret, json_params,
       impala_options_mal_formatted, dry_run);
-  EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
-      AiFunctions::AI_GENERATE_TXT_JSON_PARSE_ERROR);
+  EXPECT_TRUE(ctx->has_error());
+  EXPECT_EQ(string(ctx->error_msg(), json_parse_error_len),
+      AiFunctions::AI_GENERATE_TXT_COMMON_ERROR_PREFIX
+          + AiFunctions::AI_GENERATE_TXT_JSON_PARSE_ERROR);
+  EXPECT_EQ(result.ptr, nullptr);
+  RecreateUdfTestContext(str_desc, v, nullptr, &pool, &ctx);
 
   // Test Impala options with payload exceeding 5MB.
   string large_string(5 * 1024 * 1024 + 1, 'A');
@@ -11503,8 +11541,12 @@ TEST_P(ExprTest, AiFunctionsTest) {
   result = AiFunctions::AiGenerateTextInternal<false, 
AiFunctions::AI_PLATFORM::OPEN_AI>(
       ctx, openai_endpoint, prompt, model, jceks_secret, json_params, 
impala_options_long,
       dry_run);
-  EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
-      AiFunctions::AI_GENERATE_TXT_JSON_PARSE_ERROR);
+  EXPECT_TRUE(ctx->has_error());
+  EXPECT_EQ(string(ctx->error_msg(), json_parse_error_len),
+      AiFunctions::AI_GENERATE_TXT_COMMON_ERROR_PREFIX
+          + AiFunctions::AI_GENERATE_TXT_JSON_PARSE_ERROR);
+  EXPECT_EQ(result.ptr, nullptr);
+  RecreateUdfTestContext(str_desc, v, nullptr, &pool, &ctx);
 
   // Test PLAIN credential type.
   StringVal plain_token("test_token");
diff --git a/fe/src/main/java/org/apache/impala/service/JniFrontend.java 
b/fe/src/main/java/org/apache/impala/service/JniFrontend.java
index c737d2f17..1a797d162 100644
--- a/fe/src/main/java/org/apache/impala/service/JniFrontend.java
+++ b/fe/src/main/java/org/apache/impala/service/JniFrontend.java
@@ -134,8 +134,8 @@ public class JniFrontend {
   private final static TBinaryProtocol.Factory protocolFactory_ =
       new TBinaryProtocol.Factory();
   private final Frontend frontend_;
-  public final static String KEYSTORE_ERROR_MSG = "Failed to get password 
from" +
-      "keystore, error: invalid key '%s' or password doesn't exist";
+  public final static String KEYSTORE_ERROR_MSG = "Failed to get password from 
"
+      + "keystore, error: invalid key '%s' or password doesn't exist";
 
   /**
    * Create a new instance of the Jni Frontend.
diff --git 
a/testdata/workloads/functional-query/queries/QueryTest/ai_generate_text_exprs.test
 
b/testdata/workloads/functional-query/queries/QueryTest/ai_generate_text_exprs.test
new file mode 100644
index 000000000..654e5cf48
--- /dev/null
+++ 
b/testdata/workloads/functional-query/queries/QueryTest/ai_generate_text_exprs.test
@@ -0,0 +1,79 @@
+====
+---- QUERY
+# Incorrect password.
+select ai_generate_text('https://api.openai.com/v1/chat/completions', '', '',
+'wrong_password', '', '');
+---- RESULTS
+---- CATCH
+row_regex:.* AI Generate Text Error:.*Failed to get password from keystore.*
+====
+---- QUERY
+# Missing endpoint.
+select ai_generate_text('', 'prompt', '', '', '', '');
+---- RESULTS
+---- CATCH
+row_regex:.* AI Generate Text Error:.*Network error: curl error.*
+====
+---- QUERY
+# Invalid JSON format in impala options.
+select ai_generate_text('https://api.openai.com/v1/chat/completions', 'prompt',
+'gpt-4', '', '', '{\"wrong_format:\"random_val\"}');
+---- RESULTS
+---- CATCH
+row_regex:.* AI Generate Text Error:.*Invalid Json: Error parsing impala 
options.*
+====
+---- QUERY
+# Invalid override.
+select ai_generate_text('https://api.openai.com/v1/chat/completions', 'prompt',
+'gpt-4', '', '{\"messages\": [{\"role\":\"system\",\"content\":\"howdy!\"}]}', 
'');
+---- RESULTS
+---- CATCH
+row_regex:.* AI Generate Text Error:.*Invalid override, 'messages' cannot be
+ overriden.*
+====
+---- QUERY
+# Invalid override.
+select ai_generate_text('https://api.openai.com/v1/chat/completions', 'prompt',
+'gpt-4', '', '{\"n\": 2}', '');
+---- RESULTS
+---- CATCH
+row_regex:.* AI Generate Text Error:.*Invalid override, 'n' must be of integer
+ type and have value 1.*
+====
+---- QUERY
+# Invalid override.
+select ai_generate_text('https://api.openai.com/v1/chat/completions', 'prompt',
+'gpt-4', '', '{\"n\": \"1\"}', '');
+---- RESULTS
+---- CATCH
+row_regex:.* AI Generate Text Error:.*Invalid override, 'n' must be of integer
+ type and have value 1.*
+====
+---- QUERY
+# Invalid JSON format in impala options.
+select ai_generate_text('https://api.openai.com/v1/chat/completions', 'prompt',
+'gpt-4', '', '', '{\"payload\":\"testpayload\", 
malformatted_key:\"malformatted_content}');
+---- RESULTS
+---- CATCH
+row_regex:.* AI Generate Text Error:.*Invalid Json: Error parsing impala 
options.*
+====
+---- QUERY
+select r.r_reason_desc, s.sr_return_amt
+FROM tpcds_parquet.store_returns s, tpcds_parquet.reason r
+WHERE s.sr_reason_sk=r.r_reason_sk AND s.sr_return_amt > 10000 AND
+ai_generate_text_default(CONCAT("Categorize the return reason as 'damaged',
+'not needed', 'expensive', 'incorrect order' and 'other': ", r.r_reason_desc))
+NOT IN ('other', 'not needed');
+---- CATCH
+row_regex:.* AI Generate Text Error:.*Network error: curl error.*
+====
+---- QUERY
+select r.r_reason_desc, s.sr_return_amt
+FROM tpcds_parquet.store_returns s, tpcds_parquet.reason r
+WHERE s.sr_reason_sk=r.r_reason_sk AND s.sr_return_amt > 10000 AND
+ai_generate_text("", CONCAT("Categorize the return reason as 'damaged',
+'not needed', 'expensive', 'incorrect order' and 'other': ", 
r.r_reason_desc),"","","","")
+NOT IN ('other', 'not needed');
+---- CATCH
+row_regex:.* AI Generate Text Error:.*Network error: curl error.*
+====
diff --git a/testdata/workloads/functional-query/queries/QueryTest/exprs.test 
b/testdata/workloads/functional-query/queries/QueryTest/exprs.test
index 3e49dc776..b79067ede 100644
--- a/testdata/workloads/functional-query/queries/QueryTest/exprs.test
+++ b/testdata/workloads/functional-query/queries/QueryTest/exprs.test
@@ -3390,23 +3390,3 @@ least(cast(19.44 as decimal(4,2)), cast(18.3 as 
decimal(3,1)));
 ---- TYPES
 DECIMAL,DECIMAL,DECIMAL,DECIMAL
 ====
----- QUERY
-select r.r_reason_desc, s.sr_return_amt
-FROM tpcds_parquet.store_returns s, tpcds_parquet.reason r
-WHERE s.sr_reason_sk=r.r_reason_sk AND s.sr_return_amt > 10000 AND
-ai_generate_text_default(CONCAT("Categorize the return reason as 'damaged',
-'not needed', 'expensive', 'incorrect order' and 'other': ", r.r_reason_desc))
-NOT IN ('other', 'not needed');
----- TYPES
-string, INT
-====
----- QUERY
-select r.r_reason_desc, s.sr_return_amt
-FROM tpcds_parquet.store_returns s, tpcds_parquet.reason r
-WHERE s.sr_reason_sk=r.r_reason_sk AND s.sr_return_amt > 10000 AND
-ai_generate_text("", CONCAT("Categorize the return reason as 'damaged',
-'not needed', 'expensive', 'incorrect order' and 'other': ", 
r.r_reason_desc),"","","","")
-NOT IN ('other', 'not needed');
----- TYPES
-string, INT
-====
diff --git a/tests/custom_cluster/test_ai_generate_text.py 
b/tests/custom_cluster/test_ai_generate_text.py
new file mode 100644
index 000000000..5b8d82504
--- /dev/null
+++ b/tests/custom_cluster/test_ai_generate_text.py
@@ -0,0 +1,67 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# Tests that ai_generate_text_default() errors fail the query.
+
+from __future__ import absolute_import, division, print_function
+import pytest
+import re
+
+from tests.common.custom_cluster_test_suite import CustomClusterTestSuite
+
+
+class TestAIGenerateText(CustomClusterTestSuite):
+  @classmethod
+  def get_workload(cls):
+    return 'functional-query'
+
+  @classmethod
+  def setup_class(cls):
+    if cls.exploration_strategy() != 'exhaustive':
+      pytest.skip('runs only in exhaustive')
+    super(TestAIGenerateText, cls).setup_class()
+
+  # Using ai_generate_text_default
+  ai_generate_text_default_query = """
+      select ai_generate_text_default("test")
+      """
+
+  AI_GENERATE_COMMON_ERR_PREFIX = "AI Generate Text Error:"
+  AI_CURL_NETWORK_ERR = "Network error: curl error"
+
+  @pytest.mark.execute_serially
+  def test_inaccessible_site(self):
+    self._start_impala_cluster([
+      '--impalad_args=--ai_additional_platforms="bad.site" '
+      '--ai_endpoint="https://bad.site";'])
+    impalad = self.cluster.get_any_impalad()
+    client = impalad.service.create_beeswax_client()
+    err = self.execute_query_expect_failure(client, 
self.ai_generate_text_default_query)
+    assert re.search(re.escape(self.AI_GENERATE_COMMON_ERR_PREFIX), str(err))
+    assert re.search(re.escape(self.AI_CURL_NETWORK_ERR), str(err))
+
+  @pytest.mark.execute_serially
+  def test_emptyjceks(self):
+    self._start_impala_cluster([
+      '--impalad_args=--ai_model="gpt-4" '
+      '--ai_endpoint="https://api.openai.com/v1/chat/completions"; '
+      '--ai_api_key_jceks_secret=""'])
+    impalad = self.cluster.get_any_impalad()
+    client = impalad.service.create_beeswax_client()
+    err = self.execute_query_expect_failure(client, 
self.ai_generate_text_default_query)
+    assert re.search(re.escape(self.AI_GENERATE_COMMON_ERR_PREFIX), str(err))
+    assert re.search(re.escape(self.AI_CURL_NETWORK_ERR), str(err))
diff --git a/tests/query_test/test_exprs.py b/tests/query_test/test_exprs.py
index 0852ac67b..46946d2f9 100644
--- a/tests/query_test/test_exprs.py
+++ b/tests/query_test/test_exprs.py
@@ -134,6 +134,12 @@ class TestExprs(ImpalaTestSuite):
       assert "not supported by OpenSSL" in str(e)
       return False
 
+  def test_ai_generate_text_exprs(self, vector):
+    table_format = vector.get_value('table_format')
+    if table_format.file_format != 'parquet':
+      pytest.skip()
+    self.run_test_case('QueryTest/ai_generate_text_exprs', vector)
+
 
 # Tests very deep expression trees and expressions with many children. Impala 
defines
 # a 'safe' upper bound on the expr depth and the number of expr children in the

Reply via email to