This is an automated email from the ASF dual-hosted git repository.

arawat pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/impala.git

commit 7e8c79df53af44d2363c3cb938c4a43d5bb23a84
Author: Abhishek Rawat <[email protected]>
AuthorDate: Wed Feb 26 08:32:58 2025 -0800

    IMPALA-13792: Cross compile AI functions
    
    ai_generate_text() and ai_generate_text_default() are not
    cross-compiled and so these could lead to undefined symbols when
    codegen is enabled. This patch cross-compiles these functions.
    
    Testing:
    - Added e2e tests for ai_generate_text and ai_generate_text_default
    functions and these are run with codegen enabled/disabled.
    
    Change-Id: I454657d9f1345a36b269e6b837aaecf55a09add0
    Reviewed-on: http://gerrit.cloudera.org:8080/22552
    Reviewed-by: Impala Public Jenkins <[email protected]>
    Tested-by: Impala Public Jenkins <[email protected]>
---
 be/src/codegen/impala-ir.cc                        |  1 +
 be/src/exprs/CMakeLists.txt                        |  3 +-
 be/src/exprs/ai-functions-ir.cc                    | 40 +++++++--------
 .../{ai-functions.inline.h => ai-functions.cc}     | 59 +++++++++++++++++++---
 be/src/exprs/expr-test.cc                          |  2 +-
 .../functional-query/queries/QueryTest/exprs.test  | 20 ++++++++
 6 files changed, 94 insertions(+), 31 deletions(-)

diff --git a/be/src/codegen/impala-ir.cc b/be/src/codegen/impala-ir.cc
index 0155468ab..81248f8ba 100644
--- a/be/src/codegen/impala-ir.cc
+++ b/be/src/codegen/impala-ir.cc
@@ -40,6 +40,7 @@
 #include "exec/union-node-ir.cc"
 #include "exprs/agg-fn-evaluator-ir.cc"
 #include "exprs/aggregate-functions-ir.cc"
+#include "exprs/ai-functions-ir.cc"
 #include "exprs/bit-byte-functions-ir.cc"
 #include "exprs/cast-functions-ir.cc"
 #include "exprs/compound-predicates-ir.cc"
diff --git a/be/src/exprs/CMakeLists.txt b/be/src/exprs/CMakeLists.txt
index df5ef65f3..513ba4736 100644
--- a/be/src/exprs/CMakeLists.txt
+++ b/be/src/exprs/CMakeLists.txt
@@ -25,9 +25,9 @@ set(EXECUTABLE_OUTPUT_PATH "${BUILD_OUTPUT_ROOT_DIRECTORY}/exprs")
 set(MURMURHASH_SRC_DIR "${CMAKE_SOURCE_DIR}/be/src/thirdparty/murmurhash")
 
 add_library(ExprsIr
-  ai-functions-ir.cc
   agg-fn-evaluator-ir.cc
   aggregate-functions-ir.cc
+  ai-functions-ir.cc
   bit-byte-functions-ir.cc
   cast-functions-ir.cc
   compound-predicates-ir.cc
@@ -58,6 +58,7 @@ add_dependencies(ExprsIr gen-deps)
 add_library(Exprs
   agg-fn.cc
   agg-fn-evaluator.cc
+  ai-functions.cc
   anyval-util.cc
   case-expr.cc
   cast-format-expr.cc
diff --git a/be/src/exprs/ai-functions-ir.cc b/be/src/exprs/ai-functions-ir.cc
index d5906588a..a3a40459d 100644
--- a/be/src/exprs/ai-functions-ir.cc
+++ b/be/src/exprs/ai-functions-ir.cc
@@ -18,9 +18,14 @@
 // The functions in this file are specifically not cross-compiled to IR because there
 // is no signifcant performance benefit to be gained.
 
+#include <set>
+
 #include <boost/algorithm/string/trim.hpp>
+#include <gutil/strings/util.h>
+#include <rapidjson/document.h>
 
-#include "exprs/ai-functions.inline.h"
+#include "common/compiler-util.h"
+#include "exprs/ai-functions.h"
 
 using namespace impala_udf;
 using boost::algorithm::trim;
@@ -29,28 +34,8 @@ using std::istringstream;
 using std::set;
 using std::string_view;
 
-DEFINE_string(ai_endpoint, "https://api.openai.com/v1/chat/completions",
-    "The default API endpoint for an external AI engine.");
-DEFINE_validator(ai_endpoint, [](const char* name, const string& endpoint) {
-  return (impala::AiFunctions::is_api_endpoint_valid(endpoint) &&
-      impala::AiFunctions::is_api_endpoint_supported(endpoint));
-});
-
-DEFINE_string(ai_model, "gpt-4", "The default AI model used by an external AI engine.");
-
-DEFINE_string(ai_api_key_jceks_secret, "",
-    "The jceks secret key used for extracting the api key from configured keystores. "
-    "'hadoop.security.credential.provider.path' in core-site must be configured to "
-    "include the keystore storing the corresponding secret.");
-
-DEFINE_string(ai_additional_platforms, "",
-    "A comma-separated list of additional platforms allowed for Impala to access via "
-    "the AI api, formatted as 'site1,site2'.");
-
-DEFINE_int32(ai_connection_timeout_s, 10,
-    "(Advanced) The time in seconds for connection timed out when communicating with an "
-    "external AI engine");
-TAG_FLAG(ai_api_key_jceks_secret, sensitive);
+DECLARE_string(ai_endpoint);
+DECLARE_string(ai_additional_platforms);
 
 namespace impala {
 
@@ -267,4 +252,13 @@ StringVal AiFunctions::AiGenerateTextDefault(
       NULL_STRINGVAL, NULL_STRINGVAL, NULL_STRINGVAL);
 }
 
+// Explicit template instantiations for AiGenerateTextHelper
+template StringVal AiFunctions::AiGenerateTextHelper<true>(FunctionContext*,
+    const StringVal&, const StringVal&, const StringVal&, const StringVal&,
+    const StringVal&, const StringVal&);
+
+template StringVal AiFunctions::AiGenerateTextHelper<false>(FunctionContext*,
+    const StringVal&, const StringVal&, const StringVal&, const StringVal&,
+    const StringVal&, const StringVal&);
+
 } // namespace impala
diff --git a/be/src/exprs/ai-functions.inline.h b/be/src/exprs/ai-functions.cc
similarity index 84%
rename from be/src/exprs/ai-functions.inline.h
rename to be/src/exprs/ai-functions.cc
index fd742ed2f..a9dadf844 100644
--- a/be/src/exprs/ai-functions.inline.h
+++ b/be/src/exprs/ai-functions.cc
@@ -15,8 +15,6 @@
 // specific language governing permissions and limitations
 // under the License.
 
-#pragma once
-
 #include <gutil/strings/util.h>
 #include <rapidjson/document.h>
 #include <rapidjson/error/en.h>
@@ -36,10 +34,28 @@
 using namespace rapidjson;
 using namespace impala_udf;
 
-DECLARE_string(ai_endpoint);
-DECLARE_string(ai_model);
-DECLARE_string(ai_api_key_jceks_secret);
-DECLARE_int32(ai_connection_timeout_s);
+DEFINE_string(ai_endpoint, "https://api.openai.com/v1/chat/completions",
+    "The default API endpoint for an external AI engine.");
+DEFINE_validator(ai_endpoint, [](const char* name, const string& endpoint) {
+  return (impala::AiFunctions::is_api_endpoint_valid(endpoint) &&
+      impala::AiFunctions::is_api_endpoint_supported(endpoint));
+});
+
+DEFINE_string(ai_model, "gpt-4", "The default AI model used by an external AI engine.");
+
+DEFINE_string(ai_api_key_jceks_secret, "",
+    "The jceks secret key used for extracting the api key from configured keystores. "
+    "'hadoop.security.credential.provider.path' in core-site must be configured to "
+    "include the keystore storing the corresponding secret.");
+
+DEFINE_string(ai_additional_platforms, "",
+    "A comma-separated list of additional platforms allowed for Impala to access via "
+    "the AI api, formatted as 'site1,site2'.");
+
+DEFINE_int32(ai_connection_timeout_s, 10,
+    "(Advanced) The time in seconds for connection timed out when communicating with an "
+    "external AI engine");
+TAG_FLAG(ai_api_key_jceks_secret, sensitive);
 
 namespace impala {
 
@@ -308,4 +324,35 @@ StringVal AiFunctions::AiGenerateTextInternal(FunctionContext* ctx,
   return result;
 }
 
+// Template instantiations for getAuthorizationHeader function.
+#define INSTANTIATE_AI_AUTH_HEADER(PLATFORM) \
+    template Status getAuthorizationHeader<AiFunctions::AI_PLATFORM::PLATFORM>( \
+        string&, const std::string_view&, const AiFunctions::AiFunctionsOptions&);
+
+INSTANTIATE_AI_AUTH_HEADER(UNSUPPORTED)
+INSTANTIATE_AI_AUTH_HEADER(OPEN_AI)
+INSTANTIATE_AI_AUTH_HEADER(AZURE_OPEN_AI)
+INSTANTIATE_AI_AUTH_HEADER(GENERAL)
+
+#undef INSTANTIATE_AI_AUTH_HEADER
+
+// Template instantiations for AiGenerateTextInternal function.
+#define INSTANTIATE_AI_GENERATE_TEXT(FASTPATH, PLATFORM) \
+    template StringVal AiFunctions::AiGenerateTextInternal< \
+        FASTPATH, AiFunctions::AI_PLATFORM::PLATFORM>( \
+        FunctionContext*, const std::string_view&, const StringVal&, const StringVal&, \
+        const StringVal&, const StringVal&, const StringVal&, const bool);
+
+#define INSTANTIATE_AI_GENERATE_TEXT_FOR_PLATFORM(PLATFORM) \
+    INSTANTIATE_AI_GENERATE_TEXT(true, PLATFORM) \
+    INSTANTIATE_AI_GENERATE_TEXT(false, PLATFORM)
+
+INSTANTIATE_AI_GENERATE_TEXT_FOR_PLATFORM(UNSUPPORTED)
+INSTANTIATE_AI_GENERATE_TEXT_FOR_PLATFORM(OPEN_AI)
+INSTANTIATE_AI_GENERATE_TEXT_FOR_PLATFORM(AZURE_OPEN_AI)
+INSTANTIATE_AI_GENERATE_TEXT_FOR_PLATFORM(GENERAL)
+
+#undef INSTANTIATE_AI_GENERATE_TEXT
+#undef INSTANTIATE_AI_GENERATE_TEXT_FOR_PLATFORM
+
 } // namespace impala
diff --git a/be/src/exprs/expr-test.cc b/be/src/exprs/expr-test.cc
index d3455ecb3..b966abfa7 100644
--- a/be/src/exprs/expr-test.cc
+++ b/be/src/exprs/expr-test.cc
@@ -34,7 +34,7 @@
 #include "codegen/llvm-codegen.h"
 #include "common/init.h"
 #include "common/object-pool.h"
-#include "exprs/ai-functions.inline.h"
+#include "exprs/ai-functions.h"
 #include "exprs/anyval-util.h"
 #include "exprs/is-null-predicate.h"
 #include "exprs/like-predicate.h"
diff --git a/testdata/workloads/functional-query/queries/QueryTest/exprs.test b/testdata/workloads/functional-query/queries/QueryTest/exprs.test
index b79067ede..3e49dc776 100644
--- a/testdata/workloads/functional-query/queries/QueryTest/exprs.test
+++ b/testdata/workloads/functional-query/queries/QueryTest/exprs.test
@@ -3390,3 +3390,23 @@ least(cast(19.44 as decimal(4,2)), cast(18.3 as decimal(3,1)));
 ---- TYPES
 DECIMAL,DECIMAL,DECIMAL,DECIMAL
 ====
+---- QUERY
+select r.r_reason_desc, s.sr_return_amt
+FROM tpcds_parquet.store_returns s, tpcds_parquet.reason r
+WHERE s.sr_reason_sk=r.r_reason_sk AND s.sr_return_amt > 10000 AND
+ai_generate_text_default(CONCAT("Categorize the return reason as 'damaged',
+'not needed', 'expensive', 'incorrect order' and 'other': ", r.r_reason_desc))
+NOT IN ('other', 'not needed');
+---- TYPES
+string, INT
+====
+---- QUERY
+select r.r_reason_desc, s.sr_return_amt
+FROM tpcds_parquet.store_returns s, tpcds_parquet.reason r
+WHERE s.sr_reason_sk=r.r_reason_sk AND s.sr_return_amt > 10000 AND
+ai_generate_text("", CONCAT("Categorize the return reason as 'damaged',
+'not needed', 'expensive', 'incorrect order' and 'other': ", r.r_reason_desc),"","","","")
+NOT IN ('other', 'not needed');
+---- TYPES
+string, INT
+====

Reply via email to