This is an automated email from the ASF dual-hosted git repository. stigahuang pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/impala.git
commit 1c6ff5f98d010ad81601aaa60662dfb7bec3b5d9 Author: Yida Wu <[email protected]> AuthorDate: Mon Mar 3 11:08:53 2025 -0800 IMPALA-13812: Fail query for certain errors related to AI functions The ai_generate_text() and ai_generate_text_default() functions returned error messages as their result (a string), which could be misleading in some cases. This patch fixes this issue by setting the error in the function context as a UDF error, causing the query to fail on configuration-related errors or HTTP errors when accessing the AI endpoint. Tests: Ran core tests. Added custom testcase TestAIGenerateText for failure cases with ai_generate_text_default(). Added testcase TestExprs.test_ai_generate_text_exprs for failure cases with ai_generate_text(). Change-Id: I639e48e64d62f7990cf9a3c35a59a0ee3a2c64e0 Reviewed-on: http://gerrit.cloudera.org:8080/22588 Reviewed-by: Yida Wu <[email protected]> Tested-by: Impala Public Jenkins <[email protected]> --- be/src/exprs/ai-functions-ir.cc | 2 + be/src/exprs/ai-functions.cc | 63 +++++++++++------ be/src/exprs/ai-functions.h | 1 + be/src/exprs/expr-test.cc | 66 ++++++++++++++---- .../org/apache/impala/service/JniFrontend.java | 4 +- .../queries/QueryTest/ai_generate_text_exprs.test | 79 ++++++++++++++++++++++ .../functional-query/queries/QueryTest/exprs.test | 20 ------ tests/custom_cluster/test_ai_generate_text.py | 67 ++++++++++++++++++ tests/query_test/test_exprs.py | 6 ++ 9 files changed, 254 insertions(+), 54 deletions(-) diff --git a/be/src/exprs/ai-functions-ir.cc b/be/src/exprs/ai-functions-ir.cc index 2b36e268b..6f4772caf 100644 --- a/be/src/exprs/ai-functions-ir.cc +++ b/be/src/exprs/ai-functions-ir.cc @@ -51,6 +51,8 @@ const string AiFunctions::AI_GENERATE_TXT_MSG_OVERRIDE_FORBIDDEN_ERROR = "Invalid override, 'messages' cannot be overriden"; const string AiFunctions::AI_GENERATE_TXT_N_OVERRIDE_FORBIDDEN_ERROR = "Invalid override, 'n' must be of integer type and have value 1"; +const string 
AiFunctions::AI_GENERATE_TXT_COMMON_ERROR_PREFIX = + "AI Generate Text Error: "; string AiFunctions::ai_api_key_; const char* AiFunctions::OPEN_AI_REQUEST_FIELD_CONTENT_TYPE_HEADER = "Content-Type: application/json"; diff --git a/be/src/exprs/ai-functions.cc b/be/src/exprs/ai-functions.cc index 2e7dd6617..1ef2b079d 100644 --- a/be/src/exprs/ai-functions.cc +++ b/be/src/exprs/ai-functions.cc @@ -58,12 +58,20 @@ TAG_FLAG(ai_api_key_jceks_secret, sensitive); namespace impala { -#define RETURN_STRINGVAL_IF_ERROR(ctx, stmt) \ - do { \ - const ::impala::Status& _status = (stmt); \ - if (UNLIKELY(!_status.ok())) { \ - return copyErrorMessage(ctx, _status.msg().msg()); \ - } \ +// Set an error message in the context, causing the query to fail. +#define SET_ERROR(ctx, status_str, prefix) \ + do { \ + (ctx)->SetError((prefix + status_str).c_str()); \ + } while (false) + +// Check the status and return an error if it fails. +#define RETURN_STRINGVAL_IF_ERROR(ctx, stmt) \ + do { \ + const ::impala::Status& _status = (stmt); \ + if (UNLIKELY(!_status.ok())) { \ + SET_ERROR(ctx, _status.msg().msg(), AI_GENERATE_TXT_COMMON_ERROR_PREFIX); \ + return StringVal::null(); \ + } \ } while (false) // Impala Ai Functions Options Constants. 
@@ -181,8 +189,11 @@ StringVal AiFunctions::AiGenerateTextInternal(FunctionContext* ctx, try { ParseImpalaOptions(impala_options, impala_options_document, ai_options); } catch (const std::runtime_error& e) { - LOG(WARNING) << AI_GENERATE_TXT_JSON_PARSE_ERROR << ": " << e.what(); - return StringVal(AI_GENERATE_TXT_JSON_PARSE_ERROR.c_str()); + std::stringstream ss; + ss << AI_GENERATE_TXT_JSON_PARSE_ERROR << ": " << e.what(); + LOG(WARNING) << ss.str(); + const Status err_status(ss.str()); + RETURN_STRINGVAL_IF_ERROR(ctx, err_status); } } @@ -238,6 +249,9 @@ StringVal AiFunctions::AiGenerateTextInternal(FunctionContext* ctx, Value message(rapidjson::kObjectType); message.AddMember("role", "user", payload_allocator); if (prompt.ptr == nullptr || prompt.len == 0) { + // Return a string with the invalid prompt error message instead of failing + // the query, as the issue may be with the row rather than the configuration + // or query. This behavior might be reconsidered later. return StringVal(AI_GENERATE_TXT_INVALID_PROMPT_ERROR.c_str()); } message.AddMember("content", @@ -253,26 +267,32 @@ StringVal AiFunctions::AiGenerateTextInternal(FunctionContext* ctx, if (!fastpath && platform_params.ptr != nullptr && platform_params.len != 0) { overrides.Parse(reinterpret_cast<char*>(platform_params.ptr), platform_params.len); if (overrides.HasParseError()) { - LOG(WARNING) << AI_GENERATE_TXT_JSON_PARSE_ERROR << ": error code " - << overrides.GetParseError() << ", offset input " - << overrides.GetErrorOffset(); - return StringVal(AI_GENERATE_TXT_JSON_PARSE_ERROR.c_str()); + std::stringstream ss; + ss << AI_GENERATE_TXT_JSON_PARSE_ERROR << ": error code " + << overrides.GetParseError() << ", offset input " + << overrides.GetErrorOffset(); + LOG(WARNING) << ss.str(); + const Status err_status(ss.str()); + RETURN_STRINGVAL_IF_ERROR(ctx, err_status); } for (auto& m : overrides.GetObject()) { if (payload.HasMember(m.name.GetString())) { if (m.name == "messages") { - LOG(WARNING) - << 
AI_GENERATE_TXT_JSON_PARSE_ERROR - << ": 'messages' is constructed from 'prompt', cannot be overridden"; - return StringVal(AI_GENERATE_TXT_MSG_OVERRIDE_FORBIDDEN_ERROR.c_str()); + const string error_msg = AI_GENERATE_TXT_MSG_OVERRIDE_FORBIDDEN_ERROR + + ": 'messages' is constructed from 'prompt', cannot be overridden"; + LOG(WARNING) << error_msg; + const Status err_status(error_msg); + RETURN_STRINGVAL_IF_ERROR(ctx, err_status); } else { payload[m.name.GetString()] = m.value; } } else { if ((m.name == "n") && !(m.value.IsInt() && m.value.GetInt() == 1)) { - LOG(WARNING) << AI_GENERATE_TXT_JSON_PARSE_ERROR - << ": 'n' must be of integer type and have value 1"; - return StringVal(AI_GENERATE_TXT_N_OVERRIDE_FORBIDDEN_ERROR.c_str()); + const string error_msg = AI_GENERATE_TXT_N_OVERRIDE_FORBIDDEN_ERROR + + ": 'n' must be of integer type and have value 1"; + LOG(WARNING) << error_msg; + const Status err_status(error_msg); + RETURN_STRINGVAL_IF_ERROR(ctx, err_status); } payload.AddMember(m.name, m.value, payload_allocator); } @@ -312,7 +332,10 @@ StringVal AiFunctions::AiGenerateTextInternal(FunctionContext* ctx, status = curl.PostToURL(endpoint_str, payload_str, &resp, headers); } VLOG(2) << "AI Generate Text: \noriginal response: " << resp.ToString(); - if (UNLIKELY(!status.ok())) return copyErrorMessage(ctx, status.ToString()); + if (UNLIKELY(!status.ok())) { + SET_ERROR(ctx, status.ToString(), AI_GENERATE_TXT_COMMON_ERROR_PREFIX); + return StringVal::null(); + } // Parse the JSON response string std::string response = AiGenerateTextParseOpenAiResponse( std::string_view(reinterpret_cast<char*>(resp.data()), resp.size())); diff --git a/be/src/exprs/ai-functions.h b/be/src/exprs/ai-functions.h index d72c3b54c..85a6e1b48 100644 --- a/be/src/exprs/ai-functions.h +++ b/be/src/exprs/ai-functions.h @@ -36,6 +36,7 @@ class AiFunctions { static const string AI_GENERATE_TXT_INVALID_PROMPT_ERROR; static const string AI_GENERATE_TXT_MSG_OVERRIDE_FORBIDDEN_ERROR; static const 
string AI_GENERATE_TXT_N_OVERRIDE_FORBIDDEN_ERROR; + static const string AI_GENERATE_TXT_COMMON_ERROR_PREFIX; static const char* OPEN_AI_REQUEST_FIELD_CONTENT_TYPE_HEADER; static const char* OPEN_AI_REQUEST_AUTH_HEADER; static const char* AZURE_OPEN_AI_REQUEST_AUTH_HEADER; diff --git a/be/src/exprs/expr-test.cc b/be/src/exprs/expr-test.cc index 0662354c3..e0a117a96 100644 --- a/be/src/exprs/expr-test.cc +++ b/be/src/exprs/expr-test.cc @@ -1259,6 +1259,16 @@ class ExprTest : public testing::TestWithParam<std::tuple<bool, bool>> { UdfTestHarness::CreateTestContext(return_type, arg_types, state, pool)); } + // Helper function to close context then create a new one. + void RecreateUdfTestContext(const FunctionContext::TypeDesc& return_type, + const std::vector<FunctionContext::TypeDesc>& arg_types, RuntimeState* state, + MemPool* pool, FunctionContext** old_ctx) { + ASSERT_TRUE(old_ctx != nullptr && *old_ctx != nullptr); + UdfTestHarness::CloseContext(*old_ctx); + *old_ctx = CreateUdfTestContext(return_type, arg_types, state, pool); + ASSERT_TRUE(*old_ctx != nullptr); + } + void TestBytes(); }; @@ -11388,14 +11398,26 @@ TEST_P(ExprTest, AiFunctionsTest) { EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len), AiFunctions::AI_GENERATE_TXT_INVALID_PROMPT_ERROR); + size_t json_parse_error_len = AiFunctions::AI_GENERATE_TXT_JSON_PARSE_ERROR.size() + + AiFunctions::AI_GENERATE_TXT_COMMON_ERROR_PREFIX.size(); + size_t override_forbidden_error_len = + AiFunctions::AI_GENERATE_TXT_MSG_OVERRIDE_FORBIDDEN_ERROR.size() + + AiFunctions::AI_GENERATE_TXT_COMMON_ERROR_PREFIX.size(); + size_t n_override_forbidden_error_len = + AiFunctions::AI_GENERATE_TXT_N_OVERRIDE_FORBIDDEN_ERROR.size() + + AiFunctions::AI_GENERATE_TXT_COMMON_ERROR_PREFIX.size(); // Test override/additional params // invalid json results in error. 
StringVal invalid_json_params("{\"temperature\": 0.49, \"stop\": [\"*\",::,]}"); result = AiFunctions::AiGenerateTextInternal<false, AiFunctions::AI_PLATFORM::OPEN_AI>( ctx, openai_endpoint, prompt, model, jceks_secret, invalid_json_params, impala_options, dry_run); - EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len), - AiFunctions::AI_GENERATE_TXT_JSON_PARSE_ERROR); + EXPECT_TRUE(ctx->has_error()); + EXPECT_EQ(string(ctx->error_msg(), json_parse_error_len), + AiFunctions::AI_GENERATE_TXT_COMMON_ERROR_PREFIX + + AiFunctions::AI_GENERATE_TXT_JSON_PARSE_ERROR); + EXPECT_EQ(result.ptr, nullptr); + RecreateUdfTestContext(str_desc, v, nullptr, &pool, &ctx); // valid json results in overriding existing params ('model'), and adding new parms // like 'temperature' and 'stop'. StringVal valid_json_params( @@ -11415,22 +11437,34 @@ TEST_P(ExprTest, AiFunctionsTest) { result = AiFunctions::AiGenerateTextInternal<false, AiFunctions::AI_PLATFORM::OPEN_AI>( ctx, openai_endpoint, prompt, model, jceks_secret, forbidden_msg_override, impala_options, dry_run); - EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len), - AiFunctions::AI_GENERATE_TXT_MSG_OVERRIDE_FORBIDDEN_ERROR); + EXPECT_TRUE(ctx->has_error()); + EXPECT_EQ(string(ctx->error_msg(), override_forbidden_error_len), + AiFunctions::AI_GENERATE_TXT_COMMON_ERROR_PREFIX + + AiFunctions::AI_GENERATE_TXT_MSG_OVERRIDE_FORBIDDEN_ERROR); + EXPECT_EQ(result.ptr, nullptr); + RecreateUdfTestContext(str_desc, v, nullptr, &pool, &ctx); // 'n != 1' cannot be overriden as additional params StringVal forbidden_n_value("{\"n\": 2}"); result = AiFunctions::AiGenerateTextInternal<false, AiFunctions::AI_PLATFORM::OPEN_AI>( ctx, openai_endpoint, prompt, model, jceks_secret, forbidden_n_value, impala_options, dry_run); - EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len), - AiFunctions::AI_GENERATE_TXT_N_OVERRIDE_FORBIDDEN_ERROR); + EXPECT_TRUE(ctx->has_error()); + EXPECT_EQ(string(ctx->error_msg(), 
n_override_forbidden_error_len), + AiFunctions::AI_GENERATE_TXT_COMMON_ERROR_PREFIX + + AiFunctions::AI_GENERATE_TXT_N_OVERRIDE_FORBIDDEN_ERROR); + EXPECT_EQ(result.ptr, nullptr); + RecreateUdfTestContext(str_desc, v, nullptr, &pool, &ctx); // non integer value of 'n' cannot be overriden as additional params StringVal forbidden_n_type("{\"n\": \"1\"}"); result = AiFunctions::AiGenerateTextInternal<false, AiFunctions::AI_PLATFORM::OPEN_AI>( ctx, openai_endpoint, prompt, model, jceks_secret, forbidden_n_type, impala_options, dry_run); - EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len), - AiFunctions::AI_GENERATE_TXT_N_OVERRIDE_FORBIDDEN_ERROR); + EXPECT_TRUE(ctx->has_error()); + EXPECT_EQ(string(ctx->error_msg(), n_override_forbidden_error_len), + AiFunctions::AI_GENERATE_TXT_COMMON_ERROR_PREFIX + + AiFunctions::AI_GENERATE_TXT_N_OVERRIDE_FORBIDDEN_ERROR); + EXPECT_EQ(result.ptr, nullptr); + RecreateUdfTestContext(str_desc, v, nullptr, &pool, &ctx); // accept 'n=1' override as additional params StringVal allowed_n_override("{\"n\": 1}"); result = AiFunctions::AiGenerateTextInternal<false, AiFunctions::AI_PLATFORM::OPEN_AI>( @@ -11493,8 +11527,12 @@ TEST_P(ExprTest, AiFunctionsTest) { result = AiFunctions::AiGenerateTextInternal<false, AiFunctions::AI_PLATFORM::OPEN_AI>( ctx, openai_endpoint, prompt, model, jceks_secret, json_params, impala_options_mal_formatted, dry_run); - EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len), - AiFunctions::AI_GENERATE_TXT_JSON_PARSE_ERROR); + EXPECT_TRUE(ctx->has_error()); + EXPECT_EQ(string(ctx->error_msg(), json_parse_error_len), + AiFunctions::AI_GENERATE_TXT_COMMON_ERROR_PREFIX + + AiFunctions::AI_GENERATE_TXT_JSON_PARSE_ERROR); + EXPECT_EQ(result.ptr, nullptr); + RecreateUdfTestContext(str_desc, v, nullptr, &pool, &ctx); // Test Impala options with payload exceeding 5MB. 
string large_string(5 * 1024 * 1024 + 1, 'A'); @@ -11503,8 +11541,12 @@ TEST_P(ExprTest, AiFunctionsTest) { result = AiFunctions::AiGenerateTextInternal<false, AiFunctions::AI_PLATFORM::OPEN_AI>( ctx, openai_endpoint, prompt, model, jceks_secret, json_params, impala_options_long, dry_run); - EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len), - AiFunctions::AI_GENERATE_TXT_JSON_PARSE_ERROR); + EXPECT_TRUE(ctx->has_error()); + EXPECT_EQ(string(ctx->error_msg(), json_parse_error_len), + AiFunctions::AI_GENERATE_TXT_COMMON_ERROR_PREFIX + + AiFunctions::AI_GENERATE_TXT_JSON_PARSE_ERROR); + EXPECT_EQ(result.ptr, nullptr); + RecreateUdfTestContext(str_desc, v, nullptr, &pool, &ctx); // Test PLAIN credential type. StringVal plain_token("test_token"); diff --git a/fe/src/main/java/org/apache/impala/service/JniFrontend.java b/fe/src/main/java/org/apache/impala/service/JniFrontend.java index c737d2f17..1a797d162 100644 --- a/fe/src/main/java/org/apache/impala/service/JniFrontend.java +++ b/fe/src/main/java/org/apache/impala/service/JniFrontend.java @@ -134,8 +134,8 @@ public class JniFrontend { private final static TBinaryProtocol.Factory protocolFactory_ = new TBinaryProtocol.Factory(); private final Frontend frontend_; - public final static String KEYSTORE_ERROR_MSG = "Failed to get password from" + - "keystore, error: invalid key '%s' or password doesn't exist"; + public final static String KEYSTORE_ERROR_MSG = "Failed to get password from " + + "keystore, error: invalid key '%s' or password doesn't exist"; /** * Create a new instance of the Jni Frontend. diff --git a/testdata/workloads/functional-query/queries/QueryTest/ai_generate_text_exprs.test b/testdata/workloads/functional-query/queries/QueryTest/ai_generate_text_exprs.test new file mode 100644 index 000000000..654e5cf48 --- /dev/null +++ b/testdata/workloads/functional-query/queries/QueryTest/ai_generate_text_exprs.test @@ -0,0 +1,79 @@ +==== +---- QUERY +# Incorrect password. 
+select ai_generate_text('https://api.openai.com/v1/chat/completions', '', '', +'wrong_password', '', ''); +---- RESULTS +---- CATCH +row_regex:.* AI Generate Text Error:.*Failed to get password from keystore.* +==== +---- QUERY +# Missing endpoint. +select ai_generate_text('', 'prompt', '', '', '', ''); +---- RESULTS +---- CATCH +row_regex:.* AI Generate Text Error:.*Network error: curl error.* +==== +---- QUERY +# Invalid JSON format in impala options. +select ai_generate_text('https://api.openai.com/v1/chat/completions', 'prompt', +'gpt-4', '', '', '{\"wrong_format:\"random_val\"}'); +---- RESULTS +---- CATCH +row_regex:.* AI Generate Text Error:.*Invalid Json: Error parsing impala options.* +==== +---- QUERY +# Invalid override. +select ai_generate_text('https://api.openai.com/v1/chat/completions', 'prompt', +'gpt-4', '', '{\"messages\": [{\"role\":\"system\",\"content\":\"howdy!\"}]}', ''); +---- RESULTS +---- CATCH +row_regex:.* AI Generate Text Error:.*Invalid override, 'messages' cannot be + overriden.* +==== +---- QUERY +# Invalid override. +select ai_generate_text('https://api.openai.com/v1/chat/completions', 'prompt', +'gpt-4', '', '{\"n\": 2}', ''); +---- RESULTS +---- CATCH +row_regex:.* AI Generate Text Error:.*Invalid override, 'n' must be of integer + type and have value 1.* +==== +---- QUERY +# Invalid override. +select ai_generate_text('https://api.openai.com/v1/chat/completions', 'prompt', +'gpt-4', '', '{\"n\": \"1\"}', ''); +---- RESULTS +---- CATCH +row_regex:.* AI Generate Text Error:.*Invalid override, 'n' must be of integer + type and have value 1.* +==== +---- QUERY +# Invalid JSON format in impala options. 
+select ai_generate_text('https://api.openai.com/v1/chat/completions', 'prompt', +'gpt-4', '', '', '{\"payload\":\"testpayload\", malformatted_key:\"malformatted_content}'); +---- RESULTS +---- CATCH +row_regex:.* AI Generate Text Error:.*Invalid Json: Error parsing impala options.* +==== +---- QUERY +select r.r_reason_desc, s.sr_return_amt +FROM tpcds_parquet.store_returns s, tpcds_parquet.reason r +WHERE s.sr_reason_sk=r.r_reason_sk AND s.sr_return_amt > 10000 AND +ai_generate_text_default(CONCAT("Categorize the return reason as 'damaged', +'not needed', 'expensive', 'incorrect order' and 'other': ", r.r_reason_desc)) +NOT IN ('other', 'not needed'); +---- CATCH +row_regex:.* AI Generate Text Error:.*Network error: curl error.* +==== +---- QUERY +select r.r_reason_desc, s.sr_return_amt +FROM tpcds_parquet.store_returns s, tpcds_parquet.reason r +WHERE s.sr_reason_sk=r.r_reason_sk AND s.sr_return_amt > 10000 AND +ai_generate_text("", CONCAT("Categorize the return reason as 'damaged', +'not needed', 'expensive', 'incorrect order' and 'other': ", r.r_reason_desc),"","","","") +NOT IN ('other', 'not needed'); +---- CATCH +row_regex:.* AI Generate Text Error:.*Network error: curl error.* +==== diff --git a/testdata/workloads/functional-query/queries/QueryTest/exprs.test b/testdata/workloads/functional-query/queries/QueryTest/exprs.test index 3e49dc776..b79067ede 100644 --- a/testdata/workloads/functional-query/queries/QueryTest/exprs.test +++ b/testdata/workloads/functional-query/queries/QueryTest/exprs.test @@ -3390,23 +3390,3 @@ least(cast(19.44 as decimal(4,2)), cast(18.3 as decimal(3,1))); ---- TYPES DECIMAL,DECIMAL,DECIMAL,DECIMAL ==== ----- QUERY -select r.r_reason_desc, s.sr_return_amt -FROM tpcds_parquet.store_returns s, tpcds_parquet.reason r -WHERE s.sr_reason_sk=r.r_reason_sk AND s.sr_return_amt > 10000 AND -ai_generate_text_default(CONCAT("Categorize the return reason as 'damaged', -'not needed', 'expensive', 'incorrect order' and 'other': ", 
r.r_reason_desc)) -NOT IN ('other', 'not needed'); ----- TYPES -string, INT -==== ----- QUERY -select r.r_reason_desc, s.sr_return_amt -FROM tpcds_parquet.store_returns s, tpcds_parquet.reason r -WHERE s.sr_reason_sk=r.r_reason_sk AND s.sr_return_amt > 10000 AND -ai_generate_text("", CONCAT("Categorize the return reason as 'damaged', -'not needed', 'expensive', 'incorrect order' and 'other': ", r.r_reason_desc),"","","","") -NOT IN ('other', 'not needed'); ----- TYPES -string, INT -==== diff --git a/tests/custom_cluster/test_ai_generate_text.py b/tests/custom_cluster/test_ai_generate_text.py new file mode 100644 index 000000000..5b8d82504 --- /dev/null +++ b/tests/custom_cluster/test_ai_generate_text.py @@ -0,0 +1,67 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# Tests for ai_generate_text_default() failure handling (query must fail with a UDF error).
+ +from __future__ import absolute_import, division, print_function +import pytest +import re + +from tests.common.custom_cluster_test_suite import CustomClusterTestSuite + + +class TestAIGenerateText(CustomClusterTestSuite): + @classmethod + def get_workload(cls): + return 'functional-query' + + @classmethod + def setup_class(cls): + if cls.exploration_strategy() != 'exhaustive': + pytest.skip('runs only in exhaustive') + super(TestAIGenerateText, cls).setup_class() + + # Using ai_generate_text_default + ai_generate_text_default_query = """ + select ai_generate_text_default("test") + """ + + AI_GENERATE_COMMON_ERR_PREFIX = "AI Generate Text Error:" + AI_CURL_NETWORK_ERR = "Network error: curl error" + + @pytest.mark.execute_serially + def test_inaccessible_site(self): + self._start_impala_cluster([ + '--impalad_args=--ai_additional_platforms="bad.site" ' + '--ai_endpoint="https://bad.site"']) + impalad = self.cluster.get_any_impalad() + client = impalad.service.create_beeswax_client() + err = self.execute_query_expect_failure(client, self.ai_generate_text_default_query) + assert re.search(re.escape(self.AI_GENERATE_COMMON_ERR_PREFIX), str(err)) + assert re.search(re.escape(self.AI_CURL_NETWORK_ERR), str(err)) + + @pytest.mark.execute_serially + def test_emptyjceks(self): + self._start_impala_cluster([ + '--impalad_args=--ai_model="gpt-4" ' + '--ai_endpoint="https://api.openai.com/v1/chat/completions" ' + '--ai_api_key_jceks_secret=""']) + impalad = self.cluster.get_any_impalad() + client = impalad.service.create_beeswax_client() + err = self.execute_query_expect_failure(client, self.ai_generate_text_default_query) + assert re.search(re.escape(self.AI_GENERATE_COMMON_ERR_PREFIX), str(err)) + assert re.search(re.escape(self.AI_CURL_NETWORK_ERR), str(err)) diff --git a/tests/query_test/test_exprs.py b/tests/query_test/test_exprs.py index 0852ac67b..46946d2f9 100644 --- a/tests/query_test/test_exprs.py +++ b/tests/query_test/test_exprs.py @@ -134,6 +134,12 @@ class 
TestExprs(ImpalaTestSuite): assert "not supported by OpenSSL" in str(e) return False + def test_ai_generate_text_exprs(self, vector): + table_format = vector.get_value('table_format') + if table_format.file_format != 'parquet': + pytest.skip() + self.run_test_case('QueryTest/ai_generate_text_exprs', vector) + # Tests very deep expression trees and expressions with many children. Impala defines # a 'safe' upper bound on the expr depth and the number of expr children in the
