This is an automated email from the ASF dual-hosted git repository.

arawat pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/impala.git

commit 702131b677fb08182f69ba60a82ae2fb34b5d967
Author: Yida Wu <[email protected]>
AuthorDate: Fri Nov 22 01:27:37 2024 -0800

    IMPALA-13565: Add general AI platform support to ai_generate_text
    
    Currently only OpenAI sites are allowed for ai_generate_text().
    This patch adds support for general AI platforms to the
    ai_generate_text() function. It introduces a new flag,
    ai_additional_platforms, allowing Impala to access additional
    AI platforms. For these general AI platforms, only the OpenAI
    API standard is supported, and the default API credential serves
    as the API token for general platforms.
    
    The ai_api_key_jceks_secret parameter has been renamed to
    auth_credential to support passing both plain text and jceks
    encrypted secrets.
    
    A new impala_options parameter is added to ai_generate_text() to
    enable future extensions. It adds the api_standard option to
    impala_options, with "openai" as the only supported standard.
    It adds the credential_type option to impala_options to allow
    passing the credential as a plain-text token; by default it is
    set to jceks. It adds the payload option to impala_options for
    customized payload input. If set, the request uses the provided
    customized payload directly, and the response is parsed
    according to the OpenAI standard. The customized payload size
    must not exceed 5MB.
    
    Adding the impala_options parameter changes the signature of
    ai_generate_text(); this should be acceptable for backward
    compatibility because the function is a relatively new feature.
    
    Example:
    1. Add the site to the ai_additional_platforms flag, like:
    ai_additional_platforms='new_ai.site,new_ai.com'
    2. Example SQL:
    select ai_generate_text("https://new_ai.com/v1/chat/completions",
    "hello", "model-name", "ai-api-token", "platform params",
    '{"api_standard":"openai", "credential_type":"plain",
    "payload":"payload content"}')
    
    Tests:
    Added a new test AiFunctionsTestAdditionalSites.
    Manually tested the example with the Cloudera AI platform.
    Passed core and asan tests.
    
    Change-Id: I4ea2e1946089f262dda7ace73d5f7e37a5c98b14
    Reviewed-on: http://gerrit.cloudera.org:8080/22130
    Tested-by: Impala Public Jenkins <[email protected]>
    Reviewed-by: Abhishek Rawat <[email protected]>
---
 be/src/exprs/ai-functions-ir.cc              | 128 +++++++++---
 be/src/exprs/ai-functions.h                  |  58 ++++--
 be/src/exprs/ai-functions.inline.h           | 278 +++++++++++++++++++--------
 be/src/exprs/expr-test.cc                    | 148 ++++++++++++--
 be/src/udf/udf.cc                            |   4 +-
 be/src/udf/udf.h                             |  11 +-
 be/src/udf_samples/udf-sample.cc             |   3 +-
 common/function-registry/impala_functions.py |   2 +-
 8 files changed, 482 insertions(+), 150 deletions(-)

diff --git a/be/src/exprs/ai-functions-ir.cc b/be/src/exprs/ai-functions-ir.cc
index 2c9f17398..d5906588a 100644
--- a/be/src/exprs/ai-functions-ir.cc
+++ b/be/src/exprs/ai-functions-ir.cc
@@ -18,11 +18,16 @@
 // The functions in this file are specifically not cross-compiled to IR 
because there
 // is no signifcant performance benefit to be gained.
 
-#include <gutil/strings/util.h>
+#include <boost/algorithm/string/trim.hpp>
 
 #include "exprs/ai-functions.inline.h"
 
 using namespace impala_udf;
+using boost::algorithm::trim;
+using std::any_of;
+using std::istringstream;
+using std::set;
+using std::string_view;
 
 DEFINE_string(ai_endpoint, "https://api.openai.com/v1/chat/completions";,
     "The default API endpoint for an external AI engine.");
@@ -38,6 +43,10 @@ DEFINE_string(ai_api_key_jceks_secret, "",
     "'hadoop.security.credential.provider.path' in core-site must be 
configured to "
     "include the keystore storing the corresponding secret.");
 
+DEFINE_string(ai_additional_platforms, "",
+    "A comma-separated list of additional platforms allowed for Impala to 
access via "
+    "the AI api, formatted as 'site1,site2'.");
+
 DEFINE_int32(ai_connection_timeout_s, 10,
     "(Advanced) The time in seconds for connection timed out when 
communicating with an "
     "external AI engine");
@@ -75,27 +84,94 @@ static const char* OPEN_AI_RESPONSE_FIELD_CHOICES = 
"choices";
 static const char* OPEN_AI_RESPONSE_FIELD_MESSAGE = "message";
 static const char* OPEN_AI_RESPONSE_FIELD_CONTENT = "content";
 
-bool AiFunctions::is_api_endpoint_valid(const std::string_view& endpoint) {
+/**
+ * Singleton class for managing the additional AI platforms endpoints.
+ * The additional platforms are loaded and parsed once to optimize for 
efficiency.
+ */
+class AIAdditionalPlatforms {
+ public:
+  // Singleton accessor.
+  static AIAdditionalPlatforms& GetInstance() {
+    static AIAdditionalPlatforms instance;
+    return instance;
+  }
+
+  // Prevent copying.
+  AIAdditionalPlatforms(const AIAdditionalPlatforms&) = delete;
+  AIAdditionalPlatforms& operator=(const AIAdditionalPlatforms&) = delete;
+
+  // Check if the endpoint matches any of the additional platforms.
+  bool IsGeneralSite(const string_view& endpoint) const {
+    return any_of(additional_platforms.begin(), additional_platforms.end(),
+        [&endpoint](const string& site) {
+          return gstrncasestr(endpoint.data(), site.c_str(), endpoint.size()) 
!= nullptr;
+        });
+  }
+
+  // For testing.
+  void Reset() {
+    additional_platforms.clear();
+    ParseAdditionalSites();
+  }
+
+ private:
+  AIAdditionalPlatforms() { ParseAdditionalSites(); }
+
+  // Parse additional platforms from the flag ai_additional_platforms.
+  void ParseAdditionalSites() {
+    const string& ai_additional_platforms = FLAGS_ai_additional_platforms;
+
+    if (!ai_additional_platforms.empty()) {
+      istringstream stream(ai_additional_platforms);
+      string site;
+      LOG(INFO) << "Loading AI platform additional platforms: "
+                << ai_additional_platforms;
+
+      while (getline(stream, site, ',')) {
+        trim(site);
+        if (!site.empty()) {
+          additional_platforms.insert(site);
+          LOG(INFO) << "Loaded AI platform additional site: " << site;
+        }
+      }
+    }
+  }
+
+  // Storage of AI additional platforms;
+  set<string> additional_platforms;
+};
+
+bool AiFunctions::is_api_endpoint_valid(const string_view& endpoint) {
   // Simple validation for endpoint. It should start with https://
   return (strncaseprefix(endpoint.data(), endpoint.size(), 
AI_API_ENDPOINT_PREFIX,
               sizeof(AI_API_ENDPOINT_PREFIX))
       != nullptr);
 }
 
-bool AiFunctions::is_api_endpoint_supported(const std::string_view& endpoint) {
-  // Only OpenAI endpoints are supported.
+bool AiFunctions::is_api_endpoint_supported(const string_view& endpoint) {
+  // Only OpenAI or configured general endpoints are supported.
   return (
-      gstrncasestr(endpoint.data(), OPEN_AI_AZURE_ENDPOINT, endpoint.size()) 
!= nullptr ||
-      gstrncasestr(endpoint.data(), OPEN_AI_PUBLIC_ENDPOINT, endpoint.size()) 
!= nullptr);
+      gstrncasestr(endpoint.data(), OPEN_AI_AZURE_ENDPOINT, endpoint.size()) 
!= nullptr
+      || gstrncasestr(endpoint.data(), OPEN_AI_PUBLIC_ENDPOINT, 
endpoint.size())
+          != nullptr
+      || AIAdditionalPlatforms::GetInstance().IsGeneralSite(endpoint));
 }
 
 AiFunctions::AI_PLATFORM AiFunctions::GetAiPlatformFromEndpoint(
-    const std::string_view& endpoint) {
-  // Only OpenAI endpoints are supported.
-  if (gstrncasestr(endpoint.data(), OPEN_AI_PUBLIC_ENDPOINT, endpoint.size()) 
!= nullptr)
+    const string_view& endpoint, bool dry_run) {
+  if (UNLIKELY(dry_run)) AIAdditionalPlatforms::GetInstance().Reset();
+
+  // Only OpenAI or configured general endpoints are supported.
+  if (gstrncasestr(endpoint.data(), OPEN_AI_PUBLIC_ENDPOINT, endpoint.size())
+      != nullptr) {
     return AiFunctions::AI_PLATFORM::OPEN_AI;
-  if (gstrncasestr(endpoint.data(), OPEN_AI_AZURE_ENDPOINT, endpoint.size()) 
!= nullptr)
+  }
+  if (gstrncasestr(endpoint.data(), OPEN_AI_AZURE_ENDPOINT, endpoint.size()) 
!= nullptr) {
     return AiFunctions::AI_PLATFORM::AZURE_OPEN_AI;
+  }
+  if (AIAdditionalPlatforms::GetInstance().IsGeneralSite(endpoint)) {
+    return AI_PLATFORM::GENERAL;
+  }
   return AiFunctions::AI_PLATFORM::UNSUPPORTED;
 }
 
@@ -105,8 +181,7 @@ StringVal AiFunctions::copyErrorMessage(FunctionContext* 
ctx, const string& erro
       errorMsg.length());
 }
 
-string AiFunctions::AiGenerateTextParseOpenAiResponse(
-    const std::string_view& response) {
+string AiFunctions::AiGenerateTextParseOpenAiResponse(const string_view& 
response) {
   rapidjson::Document document;
   document.Parse(response.data(), response.size());
   // Check for parse errors
@@ -146,11 +221,12 @@ string AiFunctions::AiGenerateTextParseOpenAiResponse(
 template <bool fastpath>
 StringVal AiFunctions::AiGenerateTextHelper(FunctionContext* ctx,
     const StringVal& endpoint, const StringVal& prompt, const StringVal& model,
-    const StringVal& api_key_jceks_secret, const StringVal& params) {
-  std::string_view endpoint_sv(FLAGS_ai_endpoint);
+    const StringVal& auth_credential, const StringVal& platform_params,
+    const StringVal& impala_options) {
+  string_view endpoint_sv(FLAGS_ai_endpoint);
   // endpoint validation
   if (!fastpath && endpoint.ptr != nullptr && endpoint.len != 0) {
-    endpoint_sv = std::string_view(reinterpret_cast<char*>(endpoint.ptr), 
endpoint.len);
+    endpoint_sv = string_view(reinterpret_cast<char*>(endpoint.ptr), 
endpoint.len);
     // Simple validation for endpoint. It should start with https://
     if (!is_api_endpoint_valid(endpoint_sv)) {
       LOG(ERROR) << "AI Generate Text: \ninvalid protocol: " << endpoint_sv;
@@ -160,11 +236,15 @@ StringVal 
AiFunctions::AiGenerateTextHelper(FunctionContext* ctx,
   AI_PLATFORM platform = GetAiPlatformFromEndpoint(endpoint_sv);
   switch(platform) {
     case AI_PLATFORM::OPEN_AI:
-      return AiGenerateTextInternal<fastpath, AI_PLATFORM::OPEN_AI>(
-          ctx, endpoint_sv, prompt, model, api_key_jceks_secret, params, 
false);
+      return AiGenerateTextInternal<fastpath, AI_PLATFORM::OPEN_AI>(ctx, 
endpoint_sv,
+          prompt, model, auth_credential, platform_params, impala_options, 
false);
     case AI_PLATFORM::AZURE_OPEN_AI:
-      return AiGenerateTextInternal<fastpath, AI_PLATFORM::AZURE_OPEN_AI>(
-          ctx, endpoint_sv, prompt, model, api_key_jceks_secret, params, 
false);
+      return AiGenerateTextInternal<fastpath, AI_PLATFORM::AZURE_OPEN_AI>(ctx,
+          endpoint_sv, prompt, model, auth_credential, platform_params, 
impala_options,
+          false);
+    case AI_PLATFORM::GENERAL:
+      return AiGenerateTextInternal<fastpath, AI_PLATFORM::GENERAL>(ctx, 
endpoint_sv,
+          prompt, model, auth_credential, platform_params, impala_options, 
false);
     default:
       if (fastpath) {
         DCHECK(false) << "Default endpoint " << FLAGS_ai_endpoint << "must be 
supported";
@@ -175,16 +255,16 @@ StringVal 
AiFunctions::AiGenerateTextHelper(FunctionContext* ctx,
 }
 
 StringVal AiFunctions::AiGenerateText(FunctionContext* ctx, const StringVal& 
endpoint,
-    const StringVal& prompt, const StringVal& model,
-    const StringVal& api_key_jceks_secret, const StringVal& params) {
+    const StringVal& prompt, const StringVal& model, const StringVal& 
auth_credential,
+    const StringVal& platform_params, const StringVal& impala_options) {
   return AiGenerateTextHelper<false>(
-      ctx, endpoint, prompt, model, api_key_jceks_secret, params);
+      ctx, endpoint, prompt, model, auth_credential, platform_params, 
impala_options);
 }
 
 StringVal AiFunctions::AiGenerateTextDefault(
   FunctionContext* ctx, const StringVal& prompt) {
-  return AiGenerateTextHelper<true>(
-      ctx, NULL_STRINGVAL, prompt, NULL_STRINGVAL, NULL_STRINGVAL, 
NULL_STRINGVAL);
+  return AiGenerateTextHelper<true>(ctx, NULL_STRINGVAL, prompt, 
NULL_STRINGVAL,
+      NULL_STRINGVAL, NULL_STRINGVAL, NULL_STRINGVAL);
 }
 
 } // namespace impala
diff --git a/be/src/exprs/ai-functions.h b/be/src/exprs/ai-functions.h
index 1e3fcf8fd..d8480a93b 100644
--- a/be/src/exprs/ai-functions.h
+++ b/be/src/exprs/ai-functions.h
@@ -39,21 +39,46 @@ class AiFunctions {
   static const char* OPEN_AI_REQUEST_FIELD_CONTENT_TYPE_HEADER;
   static const char* OPEN_AI_REQUEST_AUTH_HEADER;
   static const char* AZURE_OPEN_AI_REQUEST_AUTH_HEADER;
-  enum AI_PLATFORM {
+  enum class AI_PLATFORM {
     /// Unsupported platform
     UNSUPPORTED,
     /// OpenAI public platform
     OPEN_AI,
     /// Azure OpenAI platform
-    AZURE_OPEN_AI
+    AZURE_OPEN_AI,
+    /// General AI platform
+    GENERAL
   };
-  /// Sends a prompt to the input AI endpoint using the input model, api_key 
and
-  /// optional params.
+  enum class API_STANDARD {
+    /// Unsupported standard
+    UNSUPPORTED,
+    /// OpenAI standard
+    OPEN_AI
+  };
+  enum class CREDENTIAL_TYPE {
+    /// Input credentials will be treated as plain text.
+    PLAIN,
+    /// Input credentials will be treated as a jceks secret.
+    JCEKS
+  };
+  struct AiFunctionsOptions {
+    // Default of api standard is OPEN_AI
+    AiFunctions::API_STANDARD api_standard = 
AiFunctions::API_STANDARD::OPEN_AI;
+    // Default of credential type is JCEKS.
+    AiFunctions::CREDENTIAL_TYPE credential_type = 
AiFunctions::CREDENTIAL_TYPE::JCEKS;
+    // Only valid when a customized payload is included in the request.
+    std::string_view ai_custom_payload;
+  };
+  /// Sends a prompt to the input AI endpoint using the input model, 
authentication
+  /// credential and optional platform params and impala options.
+  /// platform_params (optional) are additional AI platform specific 
parameters included
+  /// in the request sent to the AI model.
+  /// impala_options (optional) are Impala API specific options i.e 
AiFunctionsOptions.
   static StringVal AiGenerateText(FunctionContext* ctx, const StringVal& 
endpoint,
-      const StringVal& prompt, const StringVal& model,
-      const StringVal& api_key_jceks_secret, const StringVal& params);
+      const StringVal& prompt, const StringVal& model, const StringVal& 
auth_credential,
+      const StringVal& platform_params, const StringVal& impala_options);
   /// Sends a prompt to the default endpoint and uses the default model, 
default
-  /// api-key and default params.
+  /// api-key and default platform params and impala options.
   static StringVal AiGenerateTextDefault(FunctionContext* ctx, const 
StringVal& prompt);
   /// Set the ai_api_key_ member.
   static void set_api_key(string& api_key) { ai_api_key_ = api_key; }
@@ -69,27 +94,28 @@ class AiFunctions {
   /// request to the external API endpoint. If 'dry_run' is set, the POST 
request is
   /// returned. 'dry_run' mode is used only for unit tests.
   template <bool fastpath, AI_PLATFORM platform>
-  static StringVal AiGenerateTextInternal(
-      FunctionContext* ctx, const std::string_view& endpoint,
-      const StringVal& prompt, const StringVal& model,
-      const StringVal& api_key_jceks_secret, const StringVal& params, const 
bool dry_run);
+  static StringVal AiGenerateTextInternal(FunctionContext* ctx,
+      const std::string_view& endpoint, const StringVal& prompt, const 
StringVal& model,
+      const StringVal& auth_credential, const StringVal& platform_params,
+      const StringVal& impala_options, const bool dry_run);
   /// Helper function for calling AiGenerateTextInternal with common code for 
both
   /// fastpath and regular path.
   template <bool fastpath>
-  static StringVal AiGenerateTextHelper(
-    FunctionContext* ctx, const StringVal& endpoint, const StringVal& prompt,
-    const StringVal& model, const StringVal& api_key_jceks_secret,
-    const StringVal& params);
+  static StringVal AiGenerateTextHelper(FunctionContext* ctx, const StringVal& 
endpoint,
+      const StringVal& prompt, const StringVal& model, const StringVal& 
auth_credential,
+      const StringVal& platform_params, const StringVal& impala_options);
   /// Internal helper function for parsing OPEN AI's API response. Input 
parameter is the
   /// json representation of the OPEN AI's API response.
   static std::string AiGenerateTextParseOpenAiResponse(
       const std::string_view& reponse);
   /// Helper function for getting AI Platform from the endpoint
-  static AI_PLATFORM GetAiPlatformFromEndpoint(const std::string_view& 
endpoint);
+  static AI_PLATFORM GetAiPlatformFromEndpoint(
+      const std::string_view& endpoint, const bool dry_run = false);
   /// Helper functions for deep copying error message
   static StringVal copyErrorMessage(FunctionContext* ctx, const string& 
errorMsg);
 
   friend class ExprTest_AiFunctionsTest_Test;
+  friend class ExprTest_AiFunctionsTestAdditionalSites_Test;
 };
 
 } // namespace impala
diff --git a/be/src/exprs/ai-functions.inline.h 
b/be/src/exprs/ai-functions.inline.h
index bd39a5002..fd742ed2f 100644
--- a/be/src/exprs/ai-functions.inline.h
+++ b/be/src/exprs/ai-functions.inline.h
@@ -17,6 +17,7 @@
 
 #pragma once
 
+#include <gutil/strings/util.h>
 #include <rapidjson/document.h>
 #include <rapidjson/error/en.h>
 #include <rapidjson/stringbuffer.h>
@@ -50,112 +51,225 @@ namespace impala {
     }                                                      \
   } while (false)
 
-template<AiFunctions::AI_PLATFORM platform>
-Status getAuthorizationHeader(string& authHeader, const string& api_key) {
-  switch(platform) {
+// Impala Ai Functions Options Constants.
+static const char* IMPALA_AI_API_STANDARD_FIELD = "api_standard";
+static const char* IMPALA_AI_CREDENTIAL_TYPE_FIELD = "credential_type";
+static const char* IMPALA_AI_PAYLOAD_FIELD = "payload";
+static const char* IMPALA_AI_API_STANDARD_OPENAI = "openai";
+static const char* IMPALA_AI_CREDENTIAL_TYPE_PLAIN = "plain";
+static const char* IMPALA_AI_CREDENTIAL_TYPE_JCEKS = "jceks";
+static const int MAX_CUSTOM_PAYLOAD_LENGTH = 5 * 1024 * 1024; // 5MB
+
+static const size_t IMPALA_AI_API_STANDARD_OPENAI_LEN =
+    std::strlen(IMPALA_AI_API_STANDARD_OPENAI);
+static const size_t IMPALA_AI_CREDENTIAL_TYPE_PLAIN_LEN =
+    std::strlen(IMPALA_AI_CREDENTIAL_TYPE_PLAIN);
+static const size_t IMPALA_AI_CREDENTIAL_TYPE_JCEKS_LEN =
+    std::strlen(IMPALA_AI_CREDENTIAL_TYPE_JCEKS);
+
+template <AiFunctions::AI_PLATFORM platform>
+Status getAuthorizationHeader(string& authHeader, const std::string_view& 
api_key,
+    const AiFunctions::AiFunctionsOptions& ai_options) {
+  const char* header_prefix = nullptr;
+  switch (platform) {
     case AiFunctions::AI_PLATFORM::OPEN_AI:
-      authHeader = AiFunctions::OPEN_AI_REQUEST_AUTH_HEADER + api_key;
-      return Status::OK();
+      header_prefix = AiFunctions::OPEN_AI_REQUEST_AUTH_HEADER;
+      break;
     case AiFunctions::AI_PLATFORM::AZURE_OPEN_AI:
-      authHeader =  AiFunctions::AZURE_OPEN_AI_REQUEST_AUTH_HEADER + api_key;
-      return Status::OK();
+      header_prefix = AiFunctions::AZURE_OPEN_AI_REQUEST_AUTH_HEADER;
+      break;
+    case AiFunctions::AI_PLATFORM::GENERAL:
+      // For the general platform, only support OPEN_AI api standard for now.
+      if (ai_options.api_standard == AiFunctions::API_STANDARD::OPEN_AI) {
+        header_prefix = AiFunctions::OPEN_AI_REQUEST_AUTH_HEADER;
+        break;
+      }
     default:
-      DCHECK(false) <<
-          "AiGenerateTextInternal should only be called for Supported 
Platforms";
+      DCHECK(false) << "AiGenerateTextInternal should only be called for 
Supported "
+                       "Platforms and Standard";
       return Status(AiFunctions::AI_GENERATE_TXT_UNSUPPORTED_ENDPOINT_ERROR);
   }
+  DCHECK(header_prefix != nullptr);
+  authHeader = header_prefix;
+  authHeader.append(api_key);
+  return Status::OK();
+}
+
+static void ParseImpalaOptions(const StringVal& options, Document& document,
+    AiFunctions::AiFunctionsOptions& result) {
+  // If options is NULL or empty, return with defaults.
+  if (options.is_null || options.len == 0) return;
+
+  if (document.Parse(reinterpret_cast<const char*>(options.ptr), options.len)
+          .HasParseError()) {
+    std::stringstream ss;
+    ss << "Error parsing impala options: " << reinterpret_cast<const 
char*>(options.ptr)
+       << ", error code: " << document.GetParseError() << ", offset input "
+       << document.GetErrorOffset();
+    throw std::runtime_error(ss.str());
+  }
+  // Check for "api_standard" field.
+  if (document.HasMember(IMPALA_AI_API_STANDARD_FIELD)
+      && document[IMPALA_AI_API_STANDARD_FIELD].IsString()) {
+    const char* api_standard_value = 
document[IMPALA_AI_API_STANDARD_FIELD].GetString();
+    if (gstrncasestr(IMPALA_AI_API_STANDARD_OPENAI, api_standard_value,
+            IMPALA_AI_API_STANDARD_OPENAI_LEN) != nullptr) {
+      result.api_standard = AiFunctions::API_STANDARD::OPEN_AI;
+    } else {
+      result.api_standard = AiFunctions::API_STANDARD::UNSUPPORTED;
+    }
+  }
+
+  // Check for "credential_type" field.
+  if (document.HasMember(IMPALA_AI_CREDENTIAL_TYPE_FIELD)
+      && document[IMPALA_AI_CREDENTIAL_TYPE_FIELD].IsString()) {
+    const char* credential_type_value =
+        document[IMPALA_AI_CREDENTIAL_TYPE_FIELD].GetString();
+    if (gstrncasestr(IMPALA_AI_CREDENTIAL_TYPE_PLAIN, credential_type_value,
+            IMPALA_AI_CREDENTIAL_TYPE_PLAIN_LEN) != nullptr) {
+      result.credential_type = AiFunctions::CREDENTIAL_TYPE::PLAIN;
+    } else if (gstrncasestr(IMPALA_AI_CREDENTIAL_TYPE_JCEKS, 
credential_type_value,
+                   IMPALA_AI_CREDENTIAL_TYPE_JCEKS_LEN) != nullptr) {
+      result.credential_type = AiFunctions::CREDENTIAL_TYPE::JCEKS;
+    }
+  }
+
+  // Check for "payload" field.
+  if (document.HasMember(IMPALA_AI_PAYLOAD_FIELD)
+      && document[IMPALA_AI_PAYLOAD_FIELD].IsString()) {
+    const char* payload_value = document[IMPALA_AI_PAYLOAD_FIELD].GetString();
+    result.ai_custom_payload = std::string_view(payload_value);
+    // Check if payload exceeds the maximum allowed length of custom payload.
+    if (result.ai_custom_payload.length() > MAX_CUSTOM_PAYLOAD_LENGTH) {
+      std::stringstream ss;
+      ss << "Error: custom payload can't be longer than " << 
MAX_CUSTOM_PAYLOAD_LENGTH
+         << " bytes. Current length: " << result.ai_custom_payload.length();
+      result.ai_custom_payload = std::string_view();
+      throw std::runtime_error(ss.str());
+    }
+  }
 }
 
 template <bool fastpath, AiFunctions::AI_PLATFORM platform>
 StringVal AiFunctions::AiGenerateTextInternal(FunctionContext* ctx,
     const std::string_view& endpoint_sv, const StringVal& prompt, const 
StringVal& model,
-    const StringVal& api_key_jceks_secret, const StringVal& params, const bool 
dry_run) {
+    const StringVal& auth_credential, const StringVal& platform_params,
+    const StringVal& impala_options, const bool dry_run) {
   // Generate the header for the POST request
   vector<string> headers;
   headers.emplace_back(OPEN_AI_REQUEST_FIELD_CONTENT_TYPE_HEADER);
   string authHeader;
-  if (!fastpath && api_key_jceks_secret.ptr != nullptr && 
api_key_jceks_secret.len != 0) {
-    string api_key;
-    string api_key_secret(
-        reinterpret_cast<char*>(api_key_jceks_secret.ptr), 
api_key_jceks_secret.len);
-    RETURN_STRINGVAL_IF_ERROR(ctx,
-        ExecEnv::GetInstance()->frontend()->GetSecretFromKeyStore(
-            api_key_secret, &api_key));
-    RETURN_STRINGVAL_IF_ERROR(ctx,
-        getAuthorizationHeader<platform>(authHeader, api_key));
-  } else {
-    RETURN_STRINGVAL_IF_ERROR(ctx,
-        getAuthorizationHeader<platform>(authHeader, ai_api_key_));
+  AiFunctions::AiFunctionsOptions ai_options;
+  Document impala_options_document;
+
+  if (!fastpath) {
+    try {
+      ParseImpalaOptions(impala_options, impala_options_document, ai_options);
+    } catch (const std::runtime_error& e) {
+      LOG(WARNING) << AI_GENERATE_TXT_JSON_PARSE_ERROR << ": " << e.what();
+      return StringVal(AI_GENERATE_TXT_JSON_PARSE_ERROR.c_str());
+    }
   }
-  headers.emplace_back(authHeader);
-  // Generate the payload for the POST request
-  Document payload;
-  payload.SetObject();
-  Document::AllocatorType& payload_allocator = payload.GetAllocator();
-  // Azure Open AI endpoint doesn't expect model as a separate param since it's
-  // embedded in the endpoint. The 'deployment_name' below maps to a model.
-  // 
https://<resource_name>.openai.azure.com/openai/deployments/<deployment_name>/..
-  if (platform != AI_PLATFORM::AZURE_OPEN_AI) {
-    if (!fastpath && model.ptr != nullptr && model.len != 0) {
-      payload.AddMember("model",
-          rapidjson::StringRef(reinterpret_cast<char*>(model.ptr), model.len),
-          payload_allocator);
+
+  if (!fastpath && auth_credential.ptr != nullptr && auth_credential.len != 0) 
{
+    if (ai_options.credential_type == CREDENTIAL_TYPE::PLAIN) {
+      // Use the credential as a plain text token.
+      std::string_view token(
+          reinterpret_cast<char*>(auth_credential.ptr), auth_credential.len);
+      RETURN_STRINGVAL_IF_ERROR(
+          ctx, getAuthorizationHeader<platform>(authHeader, token, 
ai_options));
     } else {
-      payload.AddMember("model",
-          rapidjson::StringRef(FLAGS_ai_model.c_str(), 
FLAGS_ai_model.length()),
-          payload_allocator);
+      DCHECK(ai_options.credential_type == CREDENTIAL_TYPE::JCEKS);
+      // Use the credential as JCEKS secret and fetch API key.
+      string api_key;
+      string api_key_secret(
+          reinterpret_cast<char*>(auth_credential.ptr), auth_credential.len);
+      RETURN_STRINGVAL_IF_ERROR(ctx,
+          ExecEnv::GetInstance()->frontend()->GetSecretFromKeyStore(
+              api_key_secret, &api_key));
+      RETURN_STRINGVAL_IF_ERROR(
+          ctx, getAuthorizationHeader<platform>(authHeader, api_key, 
ai_options));
     }
+  } else {
+    RETURN_STRINGVAL_IF_ERROR(
+        ctx, getAuthorizationHeader<platform>(authHeader, ai_api_key_, 
ai_options));
   }
-  Value message_array(rapidjson::kArrayType);
-  Value message(rapidjson::kObjectType);
-  message.AddMember("role", "user", payload_allocator);
-  if (prompt.ptr == nullptr || prompt.len == 0) {
-    return StringVal(AI_GENERATE_TXT_INVALID_PROMPT_ERROR.c_str());
-  }
-  message.AddMember("content",
-      rapidjson::StringRef(reinterpret_cast<char*>(prompt.ptr), prompt.len),
-      payload_allocator);
-  message_array.PushBack(message, payload_allocator);
-  payload.AddMember("messages", message_array, payload_allocator);
-  // Override additional params.
-  // Caution: 'payload' might reference data owned by 'overrides'.
-  // To ensure valid access, place 'overrides' outside the 'if'
-  // statement before using 'payload'.
-  Document overrides;
-  if (!fastpath && params.ptr != nullptr && params.len != 0) {
-    overrides.Parse(reinterpret_cast<char*>(params.ptr), params.len);
-    if (overrides.HasParseError()) {
-      LOG(WARNING) << AI_GENERATE_TXT_JSON_PARSE_ERROR << ": error code "
-                   << overrides.GetParseError() << ", offset input "
-                   << overrides.GetErrorOffset();
-      return StringVal(AI_GENERATE_TXT_JSON_PARSE_ERROR.c_str());
+  headers.emplace_back(authHeader);
+
+  string payload_str;
+  if (!fastpath && !ai_options.ai_custom_payload.empty()) {
+    payload_str =
+        string(ai_options.ai_custom_payload.data(), 
ai_options.ai_custom_payload.size());
+  } else {
+    // Generate the payload for the POST request
+    Document payload;
+    payload.SetObject();
+    Document::AllocatorType& payload_allocator = payload.GetAllocator();
+    // Azure Open AI endpoint doesn't expect model as a separate param since 
it's
+    // embedded in the endpoint. The 'deployment_name' below maps to a model.
+    // 
https://<resource_name>.openai.azure.com/openai/deployments/<deployment_name>/..
+    if (platform != AI_PLATFORM::AZURE_OPEN_AI) {
+      if (!fastpath && model.ptr != nullptr && model.len != 0) {
+        payload.AddMember("model",
+            rapidjson::StringRef(reinterpret_cast<char*>(model.ptr), 
model.len),
+            payload_allocator);
+      } else {
+        payload.AddMember("model",
+            rapidjson::StringRef(FLAGS_ai_model.c_str(), 
FLAGS_ai_model.length()),
+            payload_allocator);
+      }
     }
-    for (auto& m : overrides.GetObject()) {
-      if (payload.HasMember(m.name.GetString())) {
-        if (m.name == "messages") {
-          LOG(WARNING)
-              << AI_GENERATE_TXT_JSON_PARSE_ERROR
-              << ": 'messages' is constructed from 'prompt', cannot be 
overridden";
-          return 
StringVal(AI_GENERATE_TXT_MSG_OVERRIDE_FORBIDDEN_ERROR.c_str());
+    Value message_array(rapidjson::kArrayType);
+    Value message(rapidjson::kObjectType);
+    message.AddMember("role", "user", payload_allocator);
+    if (prompt.ptr == nullptr || prompt.len == 0) {
+      return StringVal(AI_GENERATE_TXT_INVALID_PROMPT_ERROR.c_str());
+    }
+    message.AddMember("content",
+        rapidjson::StringRef(reinterpret_cast<char*>(prompt.ptr), prompt.len),
+        payload_allocator);
+    message_array.PushBack(message, payload_allocator);
+    payload.AddMember("messages", message_array, payload_allocator);
+    // Override additional platform params.
+    // Caution: 'payload' might reference data owned by 'overrides'.
+    // To ensure valid access, place 'overrides' outside the 'if'
+    // statement before using 'payload'.
+    Document overrides;
+    if (!fastpath && platform_params.ptr != nullptr && platform_params.len != 
0) {
+      overrides.Parse(reinterpret_cast<char*>(platform_params.ptr), 
platform_params.len);
+      if (overrides.HasParseError()) {
+        LOG(WARNING) << AI_GENERATE_TXT_JSON_PARSE_ERROR << ": error code "
+                     << overrides.GetParseError() << ", offset input "
+                     << overrides.GetErrorOffset();
+        return StringVal(AI_GENERATE_TXT_JSON_PARSE_ERROR.c_str());
+      }
+      for (auto& m : overrides.GetObject()) {
+        if (payload.HasMember(m.name.GetString())) {
+          if (m.name == "messages") {
+            LOG(WARNING)
+                << AI_GENERATE_TXT_JSON_PARSE_ERROR
+                << ": 'messages' is constructed from 'prompt', cannot be 
overridden";
+            return 
StringVal(AI_GENERATE_TXT_MSG_OVERRIDE_FORBIDDEN_ERROR.c_str());
+          } else {
+            payload[m.name.GetString()] = m.value;
+          }
         } else {
-          payload[m.name.GetString()] = m.value;
-        }
-      } else {
-        if ((m.name == "n") && !(m.value.IsInt() && m.value.GetInt() == 1)) {
-          LOG(WARNING)
-              << AI_GENERATE_TXT_JSON_PARSE_ERROR
-              << ": 'n' must be of integer type and have value 1";
-          return StringVal(AI_GENERATE_TXT_N_OVERRIDE_FORBIDDEN_ERROR.c_str());
+          if ((m.name == "n") && !(m.value.IsInt() && m.value.GetInt() == 1)) {
+            LOG(WARNING) << AI_GENERATE_TXT_JSON_PARSE_ERROR
+                         << ": 'n' must be of integer type and have value 1";
+            return 
StringVal(AI_GENERATE_TXT_N_OVERRIDE_FORBIDDEN_ERROR.c_str());
+          }
+          payload.AddMember(m.name, m.value, payload_allocator);
         }
-        payload.AddMember(m.name, m.value, payload_allocator);
       }
     }
+    // Convert payload into string for POST request
+    StringBuffer buffer;
+    Writer<StringBuffer> writer(buffer);
+    payload.Accept(writer);
+    payload_str = string(buffer.GetString(), buffer.GetSize());
   }
-  // Convert payload into string for POST request
-  StringBuffer buffer;
-  Writer<StringBuffer> writer(buffer);
-  payload.Accept(writer);
-  string payload_str(buffer.GetString(), buffer.GetSize());
+  DCHECK(!payload_str.empty());
   VLOG(2) << "AI Generate Text: \nendpoint: " << endpoint_sv
           << " \npayload: " << payload_str;
   if (UNLIKELY(dry_run)) {
diff --git a/be/src/exprs/expr-test.cc b/be/src/exprs/expr-test.cc
index f186fe1bd..d3455ecb3 100644
--- a/be/src/exprs/expr-test.cc
+++ b/be/src/exprs/expr-test.cc
@@ -85,6 +85,7 @@ DECLARE_bool(disable_optimization_passes);
 DECLARE_string(hdfs_zone_info_zip);
 DECLARE_string(ai_endpoint);
 DECLARE_string(ai_model);
+DECLARE_string(ai_additional_platforms);
 
 namespace posix_time = boost::posix_time;
 using boost::bad_lexical_cast;
@@ -11299,6 +11300,8 @@ TEST_P(ExprTest, AiFunctionsTest) {
   StringVal prompt("hello!");
   // additional params
   StringVal json_params;
+  // impala options.
+  StringVal impala_options;
   // dry_run to receive HTTP request header and body
   bool dry_run = true;
 
@@ -11312,9 +11315,9 @@ TEST_P(ExprTest, AiFunctionsTest) {
 
   // Test fastpath
   StringVal result =
-    AiFunctions::AiGenerateTextInternal<true, 
AiFunctions::AI_PLATFORM::OPEN_AI>(
-        ctx, FLAGS_ai_endpoint, prompt, StringVal::null(), StringVal::null(),
-        StringVal::null(), dry_run);
+      AiFunctions::AiGenerateTextInternal<true, 
AiFunctions::AI_PLATFORM::OPEN_AI>(ctx,
+          FLAGS_ai_endpoint, prompt, StringVal::null(), StringVal::null(),
+          StringVal::null(), StringVal::null(), dry_run);
   EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
       string("https://api.openai.com/v1/chat/completions";
              "\nContent-Type: application/json"
@@ -11323,9 +11326,9 @@ TEST_P(ExprTest, AiFunctionsTest) {
              "\"hello!\"}]}"));
 
   result =
-    AiFunctions::AiGenerateTextInternal<true, 
AiFunctions::AI_PLATFORM::AZURE_OPEN_AI>(
-        ctx, azure_openai_endpoint, prompt, StringVal::null(), 
StringVal::null(),
-        StringVal::null(), dry_run);
+      AiFunctions::AiGenerateTextInternal<true, 
AiFunctions::AI_PLATFORM::AZURE_OPEN_AI>(
+          ctx, azure_openai_endpoint, prompt, StringVal::null(), 
StringVal::null(),
+          StringVal::null(), StringVal::null(), dry_run);
   EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
       string("https://resource.openai.azure.com/openai/deployments/";
              "deployment/completions?api-version=2024-02-01"
@@ -11336,18 +11339,19 @@ TEST_P(ExprTest, AiFunctionsTest) {
 
   // Test endpoints.
   // endpoints must begin with https.
-  result = AiFunctions::AiGenerateText(
-      ctx, StringVal("http://ai.com";), prompt, model, jceks_secret, 
json_params);
+  result = AiFunctions::AiGenerateText(ctx, StringVal("http://ai.com";), 
prompt, model,
+      jceks_secret, json_params, impala_options);
   EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
       AiFunctions::AI_GENERATE_TXT_INVALID_PROTOCOL_ERROR);
   // only OpenAI endpoints are supported.
   result = AiFunctions::AiGenerateText(
-      ctx, "https://ai.com";, prompt, model, jceks_secret, json_params);
+      ctx, "https://ai.com";, prompt, model, jceks_secret, json_params, 
impala_options);
   EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
       AiFunctions::AI_GENERATE_TXT_UNSUPPORTED_ENDPOINT_ERROR);
   // valid request using OpenAI endpoint.
   result = AiFunctions::AiGenerateTextInternal<false, 
AiFunctions::AI_PLATFORM::OPEN_AI>(
-      ctx, openai_endpoint, prompt, model, jceks_secret, json_params, dry_run);
+      ctx, openai_endpoint, prompt, model, jceks_secret, json_params, 
impala_options,
+      dry_run);
   EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
       string("https://api.openai.com/v1/chat/completions";
              "\nContent-Type: application/json"
@@ -11359,7 +11363,8 @@ TEST_P(ExprTest, AiFunctionsTest) {
   // prompt cannot be empty.
   StringVal invalid_prompt("");
   result = AiFunctions::AiGenerateTextInternal<false, 
AiFunctions::AI_PLATFORM::OPEN_AI>(
-      ctx, openai_endpoint, invalid_prompt, model, jceks_secret, json_params, 
dry_run);
+      ctx, openai_endpoint, invalid_prompt, model, jceks_secret, json_params,
+      impala_options, dry_run);
   EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
       AiFunctions::AI_GENERATE_TXT_INVALID_PROMPT_ERROR);
   result = AiFunctions::AiGenerateTextDefault(ctx, invalid_prompt);
@@ -11368,7 +11373,8 @@ TEST_P(ExprTest, AiFunctionsTest) {
   // prompt cannot be null.
   invalid_prompt = StringVal::null();
   result = AiFunctions::AiGenerateTextInternal<false, 
AiFunctions::AI_PLATFORM::OPEN_AI>(
-      ctx, openai_endpoint, invalid_prompt, model, jceks_secret, json_params, 
dry_run);
+      ctx, openai_endpoint, invalid_prompt, model, jceks_secret, json_params,
+      impala_options, dry_run);
   EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
       AiFunctions::AI_GENERATE_TXT_INVALID_PROMPT_ERROR);
   result = AiFunctions::AiGenerateTextDefault(ctx, invalid_prompt);
@@ -11379,7 +11385,8 @@ TEST_P(ExprTest, AiFunctionsTest) {
   // invalid json results in error.
   StringVal invalid_json_params("{\"temperature\": 0.49, \"stop\": 
[\"*\",::,]}");
   result = AiFunctions::AiGenerateTextInternal<false, 
AiFunctions::AI_PLATFORM::OPEN_AI>(
-      ctx, openai_endpoint, prompt, model, jceks_secret, invalid_json_params, 
dry_run);
+      ctx, openai_endpoint, prompt, model, jceks_secret, invalid_json_params,
+      impala_options, dry_run);
   EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
       AiFunctions::AI_GENERATE_TXT_JSON_PARSE_ERROR);
   // valid json results in overriding existing params ('model'), and adding 
new parms
@@ -11387,7 +11394,8 @@ TEST_P(ExprTest, AiFunctionsTest) {
   StringVal valid_json_params(
       "{\"model\": \"gpt\", \"temperature\": 0.49, \"stop\": [\"*\", \"%\"]}");
   result = AiFunctions::AiGenerateTextInternal<false, 
AiFunctions::AI_PLATFORM::OPEN_AI>(
-      ctx, openai_endpoint, prompt, model, jceks_secret, valid_json_params, 
dry_run);
+      ctx, openai_endpoint, prompt, model, jceks_secret, valid_json_params,
+      impala_options, dry_run);
   EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
       string("https://api.openai.com/v1/chat/completions";
              "\nContent-Type: application/json"
@@ -11398,25 +11406,29 @@ TEST_P(ExprTest, AiFunctionsTest) {
   StringVal forbidden_msg_override(
       "{\"messages\": [{\"role\":\"system\",\"content\":\"howdy!\"}]}");
   result = AiFunctions::AiGenerateTextInternal<false, 
AiFunctions::AI_PLATFORM::OPEN_AI>(
-      ctx, openai_endpoint, prompt, model, jceks_secret, 
forbidden_msg_override, dry_run);
+      ctx, openai_endpoint, prompt, model, jceks_secret, 
forbidden_msg_override,
+      impala_options, dry_run);
   EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
       AiFunctions::AI_GENERATE_TXT_MSG_OVERRIDE_FORBIDDEN_ERROR);
   // 'n != 1' cannot be overriden as additional params
   StringVal forbidden_n_value("{\"n\": 2}");
   result = AiFunctions::AiGenerateTextInternal<false, 
AiFunctions::AI_PLATFORM::OPEN_AI>(
-      ctx, openai_endpoint, prompt, model, jceks_secret, forbidden_n_value, 
dry_run);
+      ctx, openai_endpoint, prompt, model, jceks_secret, forbidden_n_value,
+      impala_options, dry_run);
   EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
       AiFunctions::AI_GENERATE_TXT_N_OVERRIDE_FORBIDDEN_ERROR);
   // non integer value of 'n' cannot be overriden as additional params
   StringVal forbidden_n_type("{\"n\": \"1\"}");
   result = AiFunctions::AiGenerateTextInternal<false, 
AiFunctions::AI_PLATFORM::OPEN_AI>(
-      ctx, openai_endpoint, prompt, model, jceks_secret, forbidden_n_type, 
dry_run);
+      ctx, openai_endpoint, prompt, model, jceks_secret, forbidden_n_type, 
impala_options,
+      dry_run);
   EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
       AiFunctions::AI_GENERATE_TXT_N_OVERRIDE_FORBIDDEN_ERROR);
   // accept 'n=1' override as additional params
   StringVal allowed_n_override("{\"n\": 1}");
   result = AiFunctions::AiGenerateTextInternal<false, 
AiFunctions::AI_PLATFORM::OPEN_AI>(
-      ctx, openai_endpoint, prompt, model, jceks_secret, allowed_n_override, 
dry_run);
+      ctx, openai_endpoint, prompt, model, jceks_secret, allowed_n_override,
+      impala_options, dry_run);
   EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
       string("https://api.openai.com/v1/chat/completions";
              "\nContent-Type: application/json"
@@ -11427,7 +11439,7 @@ TEST_P(ExprTest, AiFunctionsTest) {
   // Test flag file options are used when input is empty/null
   result = AiFunctions::AiGenerateTextInternal<false, 
AiFunctions::AI_PLATFORM::OPEN_AI>(
       ctx, FLAGS_ai_endpoint, prompt, StringVal::null(), jceks_secret, 
json_params,
-      dry_run);
+      impala_options, dry_run);
   EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
       string("https://api.openai.com/v1/chat/completions";
              "\nContent-Type: application/json"
@@ -11435,7 +11447,8 @@ TEST_P(ExprTest, AiFunctionsTest) {
              
"\n{\"model\":\"gpt-4\",\"messages\":[{\"role\":\"user\",\"content\":"
              "\"hello!\"}]}"));
   result = AiFunctions::AiGenerateTextInternal<false, 
AiFunctions::AI_PLATFORM::OPEN_AI>(
-      ctx, FLAGS_ai_endpoint, prompt, StringVal(""), jceks_secret, 
json_params, dry_run);
+      ctx, FLAGS_ai_endpoint, prompt, StringVal(""), jceks_secret, json_params,
+      impala_options, dry_run);
   EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
       string("https://api.openai.com/v1/chat/completions";
              "\nContent-Type: application/json"
@@ -11443,6 +11456,75 @@ TEST_P(ExprTest, AiFunctionsTest) {
              
"\n{\"model\":\"gpt-4\",\"messages\":[{\"role\":\"user\",\"content\":"
              "\"hello!\"}]}"));
 
+  // Test Impala options.
+  StringVal impala_options_payload("{\"payload\":\"testpayload\"}");
+  result = AiFunctions::AiGenerateTextInternal<false, 
AiFunctions::AI_PLATFORM::OPEN_AI>(
+      ctx, openai_endpoint, prompt, model, jceks_secret, json_params,
+      impala_options_payload, dry_run);
+  EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
+      string("https://api.openai.com/v1/chat/completions";
+             "\nContent-Type: application/json"
+             "\nAuthorization: Bearer do_not_share"
+             "\ntestpayload"));
+
+  // Test a not supported Impala option, doesn't affect the results.
+  StringVal impala_options_payload_extra(
+      "{\"payload\":\"testpayload\", "
+      "\"not_supported_key\":\"not_supported_content\"}");
+  result = AiFunctions::AiGenerateTextInternal<false, 
AiFunctions::AI_PLATFORM::OPEN_AI>(
+      ctx, openai_endpoint, prompt, model, jceks_secret, json_params,
+      impala_options_payload_extra, dry_run);
+  EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
+      string("https://api.openai.com/v1/chat/completions";
+             "\nContent-Type: application/json"
+             "\nAuthorization: Bearer do_not_share"
+             "\ntestpayload"));
+
+  // Test an Impala option with malformatted json.
+  StringVal impala_options_mal_formatted("{\"payload\":\"testpayload\", "
+                                         
"malformatted_key:\"malformatted_content}");
+  result = AiFunctions::AiGenerateTextInternal<false, 
AiFunctions::AI_PLATFORM::OPEN_AI>(
+      ctx, openai_endpoint, prompt, model, jceks_secret, json_params,
+      impala_options_mal_formatted, dry_run);
+  EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
+      AiFunctions::AI_GENERATE_TXT_JSON_PARSE_ERROR);
+
+  // Test Impala options with payload exceeding 5MB.
+  string large_string(5 * 1024 * 1024 + 1, 'A');
+  string large_payload = "{\"payload\":\"" + large_string + "\"}";
+  StringVal impala_options_long(large_payload.c_str());
+  result = AiFunctions::AiGenerateTextInternal<false, 
AiFunctions::AI_PLATFORM::OPEN_AI>(
+      ctx, openai_endpoint, prompt, model, jceks_secret, json_params, 
impala_options_long,
+      dry_run);
+  EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
+      AiFunctions::AI_GENERATE_TXT_JSON_PARSE_ERROR);
+
+  // Test PLAIN credential type.
+  StringVal plain_token("test_token");
+  StringVal plain_token_options(
+      "{\"credential_type\":\"plain\",\"api_standard\":\"openai\"}");
+  result = AiFunctions::AiGenerateTextInternal<false, 
AiFunctions::AI_PLATFORM::OPEN_AI>(
+      ctx, openai_endpoint, prompt, model, plain_token, json_params, 
plain_token_options,
+      dry_run);
+  EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
+      string("https://api.openai.com/v1/chat/completions";
+             "\nContent-Type: application/json"
+             "\nAuthorization: Bearer test_token"
+             
"\n{\"model\":\"bot\",\"messages\":[{\"role\":\"user\",\"content\":"
+             "\"hello!\"}]}"));
+
+  // Test PLAIN credential type with customized payload.
+  plain_token_options = 
StringVal("{\"credential_type\":\"plain\",\"api_standard\":"
+                                  "\"openai\", \"payload\":\"testpayload\"}");
+  result = AiFunctions::AiGenerateTextInternal<false, 
AiFunctions::AI_PLATFORM::OPEN_AI>(
+      ctx, openai_endpoint, prompt, model, plain_token, json_params, 
plain_token_options,
+      dry_run);
+  EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
+      string("https://api.openai.com/v1/chat/completions";
+             "\nContent-Type: application/json"
+             "\nAuthorization: Bearer test_token"
+             "\ntestpayload"));
+
   // Test OPEN AI's API response parsing
   string content(
       "A null-terminated string is a character string in a programming "
@@ -11485,6 +11567,32 @@ TEST_P(ExprTest, AiFunctionsTest) {
   state.ReleaseResources();
 }
 
+TEST_P(ExprTest, AiFunctionsTestAdditionalSites) {
+  google::FlagSaver flag_saver;  // Restores FLAGS_ai_additional_platforms on exit.
+  FLAGS_ai_additional_platforms = "ai-api.com , another-ai.org";
+  // Built-in OpenAI and Azure OpenAI endpoints keep their dedicated platforms.
+  EXPECT_EQ(AiFunctions::GetAiPlatformFromEndpoint(
+                "https://api.openai.com/v1/chat/completions", true),
+      AiFunctions::AI_PLATFORM::OPEN_AI);
+  EXPECT_EQ(AiFunctions::GetAiPlatformFromEndpoint(
+                "https://openai.azure.com/openai/deployments/", true),
+      AiFunctions::AI_PLATFORM::AZURE_OPEN_AI);
+  // Listed sites map to GENERAL; spaces around the commas are tolerated.
+  EXPECT_EQ(
+      AiFunctions::GetAiPlatformFromEndpoint("https://ai-api.com/v1/generate", true),
+      AiFunctions::AI_PLATFORM::GENERAL);
+  EXPECT_EQ(
+      AiFunctions::GetAiPlatformFromEndpoint("https://another-ai.org/completions", true),
+      AiFunctions::AI_PLATFORM::GENERAL);
+  // A site that is neither built-in nor listed resolves to UNSUPPORTED.
+  EXPECT_EQ(AiFunctions::GetAiPlatformFromEndpoint("https://random-api.site", true),
+      AiFunctions::AI_PLATFORM::UNSUPPORTED);
+  // Matching against the configured sites ignores case.
+  EXPECT_EQ(
+      AiFunctions::GetAiPlatformFromEndpoint("https://AI-API.COM/v1/generate", true),
+      AiFunctions::AI_PLATFORM::GENERAL);
+}
+
 } // namespace impala
 
 INSTANTIATE_TEST_SUITE_P(Instantiations, ExprTest, ::testing::Values(
diff --git a/be/src/udf/udf.cc b/be/src/udf/udf.cc
index 10390824f..51f0ec933 100644
--- a/be/src/udf/udf.cc
+++ b/be/src/udf/udf.cc
@@ -123,8 +123,8 @@ using impala_udf::FunctionContext;
 class AiFunctions {
  public:
   static StringVal AiGenerateText(FunctionContext* ctx, const StringVal& 
endpoint,
-      const StringVal& prompt, const StringVal& model,
-      const StringVal& api_key_jceks_secret, const StringVal& params) {
+      const StringVal& prompt, const StringVal& model, const StringVal& 
auth_credential,
+      const StringVal& params, const StringVal& impala_options) {
     return StringVal(AI_FUNCTIONS_DUMMY_RESPONSE.c_str());
   }
   static StringVal AiGenerateTextDefault(FunctionContext* ctx, const 
StringVal& prompt) {
diff --git a/be/src/udf/udf.h b/be/src/udf/udf.h
index f997f7a07..c96f4731e 100644
--- a/be/src/udf/udf.h
+++ b/be/src/udf/udf.h
@@ -67,14 +67,17 @@ class FunctionContext;
 struct BuiltInFunctions {
  public:
   /// Sends a prompt to the default endpoint and uses the default model, default
-  /// jceks api-key secret and default params.
+  /// auth credential, platform params and impala options.
   StringVal (*ai_generate_text_default)(
       FunctionContext* context, const StringVal& prompt);
-  /// Sends a prompt to the input AI endpoint using the input model, jceks api_key secret
-  /// and optional params.
+  /// Sends a prompt to the input AI endpoint using the input model, authentication
+  /// credential and optional platform params and impala options.
+  /// The credential is a jceks api_key secret (the default) or a plain-text token,
+  /// as selected by the "credential_type" key of the impala_options JSON.
   StringVal (*ai_generate_text)(FunctionContext* context, const StringVal& endpoint,
       const StringVal& prompt, const StringVal& model,
-      const StringVal& api_key_jceks_secret, const StringVal& params);
+      const StringVal& api_auth_credential, const StringVal& platform_params,
+      const StringVal& impala_options);
 };
 
 /// A FunctionContext is passed to every UDF/UDA and is the interface for the 
UDF to the
diff --git a/be/src/udf_samples/udf-sample.cc b/be/src/udf_samples/udf-sample.cc
index 2c3d380ec..a8ca5f8b0 100644
--- a/be/src/udf_samples/udf-sample.cc
+++ b/be/src/udf_samples/udf-sample.cc
@@ -51,6 +51,7 @@ StringVal ClassifyReviews(FunctionContext* context, const 
StringVal& input) {
   const StringVal model("gpt-3.5-turbo");
   const StringVal api_key_jceks_secret("open-ai-key");
   const StringVal params("{\"temperature\": 0.9, \"model\": \"gpt-4\"}");
+  const StringVal options("{\"credential_type\": \"JCEKS\"}");
   return context->Functions()->ai_generate_text(
-      context, endpoint, prompt, model, api_key_jceks_secret, params);
+      context, endpoint, prompt, model, api_key_jceks_secret, params, options);
 }
diff --git a/common/function-registry/impala_functions.py 
b/common/function-registry/impala_functions.py
index 76a3dc4e1..f2ed85f97 100644
--- a/common/function-registry/impala_functions.py
+++ b/common/function-registry/impala_functions.py
@@ -517,7 +517,7 @@ visible_functions = [
   [['repeat'], 'STRING', ['STRING', 'BIGINT'], 
'impala::StringFunctions::Repeat'],
   [['lpad'], 'STRING', ['STRING', 'BIGINT', 'STRING'], 
'impala::StringFunctions::Lpad'],
   [['rpad'], 'STRING', ['STRING', 'BIGINT', 'STRING'], 
'impala::StringFunctions::Rpad'],
-  [['ai_generate_text'], 'STRING', ['STRING', 'STRING', 'STRING', 'STRING', 
'STRING'],
+  [['ai_generate_text'], 'STRING', ['STRING', 'STRING', 'STRING', 'STRING', 
'STRING', 'STRING'],
    'impala::AiFunctions::AiGenerateText'],
   [['ai_generate_text_default'], 'STRING', ['STRING'],
    'impala::AiFunctions::AiGenerateTextDefault'],

Reply via email to