This is an automated email from the ASF dual-hosted git repository.

arawat pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/impala.git


The following commit(s) were added to refs/heads/master by this push:
     new 3668a9517 IMPALA-13131: Azure OpenAI API expects 'api-key' instead of 
'Authorization' in the request header
3668a9517 is described below

commit 3668a9517c4d8097591ed3b6fa672bf87faa77f6
Author: Abhishek Rawat <[email protected]>
AuthorDate: Fri Jun 7 07:13:58 2024 -0700

    IMPALA-13131: Azure OpenAI API expects 'api-key' instead of 'Authorization' 
in the request header
    
    Updated the POST request when communicating with Azure Open AI
    endpoint. The header now includes 'api-key: <api-key>' instead of
    'Authorization: Bearer <api-key>'.
    
    Also, removed 'model' as a required param for the Azure Open AI api
    call. This is mainly because the endpoint contains deployment which
    is basically already mapped to a model.
    
    Testing:
    - Updated existing unit test as per the Azure API reference
    - Manually tested builtin 'ai_generate_text' using an Azure Open AI
    deployment.
    
    Change-Id: If9cc07940ce355d511bcf0ee615ff31042d13eb5
    Reviewed-on: http://gerrit.cloudera.org:8080/21493
    Reviewed-by: Impala Public Jenkins <[email protected]>
    Tested-by: Impala Public Jenkins <[email protected]>
---
 be/src/exprs/ai-functions-ir.cc    | 59 +++++++++++++++++++++++--
 be/src/exprs/ai-functions.h        | 26 ++++++++++-
 be/src/exprs/ai-functions.inline.h | 88 +++++++++++++++++++++-----------------
 be/src/exprs/expr-test.cc          | 78 +++++++++++++++++++++++----------
 4 files changed, 183 insertions(+), 68 deletions(-)

diff --git a/be/src/exprs/ai-functions-ir.cc b/be/src/exprs/ai-functions-ir.cc
index 6def1a010..2c9f17398 100644
--- a/be/src/exprs/ai-functions-ir.cc
+++ b/be/src/exprs/ai-functions-ir.cc
@@ -60,6 +60,10 @@ const string 
AiFunctions::AI_GENERATE_TXT_N_OVERRIDE_FORBIDDEN_ERROR =
 string AiFunctions::ai_api_key_;
 const char* AiFunctions::OPEN_AI_REQUEST_FIELD_CONTENT_TYPE_HEADER =
     "Content-Type: application/json";
+const char* AiFunctions::OPEN_AI_REQUEST_AUTH_HEADER =
+    "Authorization: Bearer ";
+const char* AiFunctions::AZURE_OPEN_AI_REQUEST_AUTH_HEADER =
+    "api-key: ";
 
 // other constants
 static const StringVal NULL_STRINGVAL = StringVal::null();
@@ -85,6 +89,22 @@ bool AiFunctions::is_api_endpoint_supported(const 
std::string_view& endpoint) {
       gstrncasestr(endpoint.data(), OPEN_AI_PUBLIC_ENDPOINT, endpoint.size()) 
!= nullptr);
 }
 
+AiFunctions::AI_PLATFORM AiFunctions::GetAiPlatformFromEndpoint(
+    const std::string_view& endpoint) {
+  // Only OpenAI endpoints are supported.
+  if (gstrncasestr(endpoint.data(), OPEN_AI_PUBLIC_ENDPOINT, endpoint.size()) 
!= nullptr)
+    return AiFunctions::AI_PLATFORM::OPEN_AI;
+  if (gstrncasestr(endpoint.data(), OPEN_AI_AZURE_ENDPOINT, endpoint.size()) 
!= nullptr)
+    return AiFunctions::AI_PLATFORM::AZURE_OPEN_AI;
+  return AiFunctions::AI_PLATFORM::UNSUPPORTED;
+}
+
+StringVal AiFunctions::copyErrorMessage(FunctionContext* ctx, const string& 
errorMsg) {
+  return StringVal::CopyFrom(ctx,
+      reinterpret_cast<const uint8_t*>(errorMsg.c_str()),
+      errorMsg.length());
+}
+
 string AiFunctions::AiGenerateTextParseOpenAiResponse(
     const std::string_view& response) {
   rapidjson::Document document;
@@ -123,17 +143,48 @@ string AiFunctions::AiGenerateTextParseOpenAiResponse(
   return message[OPEN_AI_RESPONSE_FIELD_CONTENT].GetString();
 }
 
+template <bool fastpath>
+StringVal AiFunctions::AiGenerateTextHelper(FunctionContext* ctx,
+    const StringVal& endpoint, const StringVal& prompt, const StringVal& model,
+    const StringVal& api_key_jceks_secret, const StringVal& params) {
+  std::string_view endpoint_sv(FLAGS_ai_endpoint);
+  // endpoint validation
+  if (!fastpath && endpoint.ptr != nullptr && endpoint.len != 0) {
+    endpoint_sv = std::string_view(reinterpret_cast<char*>(endpoint.ptr), 
endpoint.len);
+    // Simple validation for endpoint. It should start with https://
+    if (!is_api_endpoint_valid(endpoint_sv)) {
+      LOG(ERROR) << "AI Generate Text: \ninvalid protocol: " << endpoint_sv;
+      return StringVal(AI_GENERATE_TXT_INVALID_PROTOCOL_ERROR.c_str());
+    }
+  }
+  AI_PLATFORM platform = GetAiPlatformFromEndpoint(endpoint_sv);
+  switch(platform) {
+    case AI_PLATFORM::OPEN_AI:
+      return AiGenerateTextInternal<fastpath, AI_PLATFORM::OPEN_AI>(
+          ctx, endpoint_sv, prompt, model, api_key_jceks_secret, params, 
false);
+    case AI_PLATFORM::AZURE_OPEN_AI:
+      return AiGenerateTextInternal<fastpath, AI_PLATFORM::AZURE_OPEN_AI>(
+          ctx, endpoint_sv, prompt, model, api_key_jceks_secret, params, 
false);
+    default:
+      if (fastpath) {
+        DCHECK(false) << "Default endpoint " << FLAGS_ai_endpoint << "must be 
supported";
+      }
+      LOG(ERROR) << "AI Generate Text: \nunsupported endpoint: " << 
endpoint_sv;
+      return StringVal(AI_GENERATE_TXT_UNSUPPORTED_ENDPOINT_ERROR.c_str());
+  }
+}
+
 StringVal AiFunctions::AiGenerateText(FunctionContext* ctx, const StringVal& 
endpoint,
     const StringVal& prompt, const StringVal& model,
     const StringVal& api_key_jceks_secret, const StringVal& params) {
-  return AiGenerateTextInternal<false>(
-      ctx, endpoint, prompt, model, api_key_jceks_secret, params, false);
+  return AiGenerateTextHelper<false>(
+      ctx, endpoint, prompt, model, api_key_jceks_secret, params);
 }
 
 StringVal AiFunctions::AiGenerateTextDefault(
   FunctionContext* ctx, const StringVal& prompt) {
-  return AiGenerateTextInternal<true>(
-      ctx, NULL_STRINGVAL, prompt, NULL_STRINGVAL, NULL_STRINGVAL, 
NULL_STRINGVAL, false);
+  return AiGenerateTextHelper<true>(
+      ctx, NULL_STRINGVAL, prompt, NULL_STRINGVAL, NULL_STRINGVAL, 
NULL_STRINGVAL);
 }
 
 } // namespace impala
diff --git a/be/src/exprs/ai-functions.h b/be/src/exprs/ai-functions.h
index 0e6396b40..1e3fcf8fd 100644
--- a/be/src/exprs/ai-functions.h
+++ b/be/src/exprs/ai-functions.h
@@ -37,6 +37,16 @@ class AiFunctions {
   static const string AI_GENERATE_TXT_MSG_OVERRIDE_FORBIDDEN_ERROR;
   static const string AI_GENERATE_TXT_N_OVERRIDE_FORBIDDEN_ERROR;
   static const char* OPEN_AI_REQUEST_FIELD_CONTENT_TYPE_HEADER;
+  static const char* OPEN_AI_REQUEST_AUTH_HEADER;
+  static const char* AZURE_OPEN_AI_REQUEST_AUTH_HEADER;
+  enum AI_PLATFORM {
+    /// Unsupported platform
+    UNSUPPORTED,
+    /// OpenAI public platform
+    OPEN_AI,
+    /// Azure OpenAI platform
+    AZURE_OPEN_AI
+  };
   /// Sends a prompt to the input AI endpoint using the input model, api_key 
and
   /// optional params.
   static StringVal AiGenerateText(FunctionContext* ctx, const StringVal& 
endpoint,
@@ -58,14 +68,26 @@ class AiFunctions {
   /// Internal function which implements the logic of parsing user input and 
sending
   /// request to the external API endpoint. If 'dry_run' is set, the POST 
request is
   /// returned. 'dry_run' mode is used only for unit tests.
-  template <bool fastpath>
-  static StringVal AiGenerateTextInternal(FunctionContext* ctx, const 
StringVal& endpoint,
+  template <bool fastpath, AI_PLATFORM platform>
+  static StringVal AiGenerateTextInternal(
+      FunctionContext* ctx, const std::string_view& endpoint,
       const StringVal& prompt, const StringVal& model,
       const StringVal& api_key_jceks_secret, const StringVal& params, const 
bool dry_run);
+  /// Helper function for calling AiGenerateTextInternal with common code for 
both
+  /// fastpath and regular path.
+  template <bool fastpath>
+  static StringVal AiGenerateTextHelper(
+    FunctionContext* ctx, const StringVal& endpoint, const StringVal& prompt,
+    const StringVal& model, const StringVal& api_key_jceks_secret,
+    const StringVal& params);
   /// Internal helper function for parsing OPEN AI's API response. Input 
parameter is the
   /// json representation of the OPEN AI's API response.
   static std::string AiGenerateTextParseOpenAiResponse(
       const std::string_view& reponse);
+  /// Helper function for getting AI Platform from the endpoint
+  static AI_PLATFORM GetAiPlatformFromEndpoint(const std::string_view& 
endpoint);
+  /// Helper functions for deep copying error message
+  static StringVal copyErrorMessage(FunctionContext* ctx, const string& 
errorMsg);
 
   friend class ExprTest_AiFunctionsTest_Test;
 };
diff --git a/be/src/exprs/ai-functions.inline.h 
b/be/src/exprs/ai-functions.inline.h
index 9f143e2df..bd39a5002 100644
--- a/be/src/exprs/ai-functions.inline.h
+++ b/be/src/exprs/ai-functions.inline.h
@@ -42,55 +42,69 @@ DECLARE_int32(ai_connection_timeout_s);
 
 namespace impala {
 
-template <bool fastpath>
+#define RETURN_STRINGVAL_IF_ERROR(ctx, stmt)               \
+  do {                                                     \
+    const ::impala::Status& _status = (stmt);              \
+    if (UNLIKELY(!_status.ok())) {                         \
+      return copyErrorMessage(ctx, _status.msg().msg());   \
+    }                                                      \
+  } while (false)
+
+template<AiFunctions::AI_PLATFORM platform>
+Status getAuthorizationHeader(string& authHeader, const string& api_key) {
+  switch(platform) {
+    case AiFunctions::AI_PLATFORM::OPEN_AI:
+      authHeader = AiFunctions::OPEN_AI_REQUEST_AUTH_HEADER + api_key;
+      return Status::OK();
+    case AiFunctions::AI_PLATFORM::AZURE_OPEN_AI:
+      authHeader =  AiFunctions::AZURE_OPEN_AI_REQUEST_AUTH_HEADER + api_key;
+      return Status::OK();
+    default:
+      DCHECK(false) <<
+          "AiGenerateTextInternal should only be called for Supported 
Platforms";
+      return Status(AiFunctions::AI_GENERATE_TXT_UNSUPPORTED_ENDPOINT_ERROR);
+  }
+}
+
+template <bool fastpath, AiFunctions::AI_PLATFORM platform>
 StringVal AiFunctions::AiGenerateTextInternal(FunctionContext* ctx,
-    const StringVal& endpoint, const StringVal& prompt, const StringVal& model,
+    const std::string_view& endpoint_sv, const StringVal& prompt, const 
StringVal& model,
     const StringVal& api_key_jceks_secret, const StringVal& params, const bool 
dry_run) {
-  std::string_view endpoint_sv(FLAGS_ai_endpoint);
-  // endpoint validation
-  if (!fastpath && endpoint.ptr != nullptr && endpoint.len != 0) {
-    endpoint_sv = std::string_view(reinterpret_cast<char*>(endpoint.ptr), 
endpoint.len);
-    // Simple validation for endpoint. It should start with https://
-    if (!is_api_endpoint_valid(endpoint_sv)) {
-      LOG(ERROR) << "AI Generate Text: \ninvalid protocol: " << endpoint_sv;
-      return StringVal(AI_GENERATE_TXT_INVALID_PROTOCOL_ERROR.c_str());
-    }
-    // Only OpenAI endpoints are supported.
-    if (!is_api_endpoint_supported(endpoint_sv)) {
-      LOG(ERROR) << "AI Generate Text: \nunsupported endpoint: " << 
endpoint_sv;
-      return StringVal(AI_GENERATE_TXT_UNSUPPORTED_ENDPOINT_ERROR.c_str());
-    }
-  }
   // Generate the header for the POST request
   vector<string> headers;
   headers.emplace_back(OPEN_AI_REQUEST_FIELD_CONTENT_TYPE_HEADER);
+  string authHeader;
   if (!fastpath && api_key_jceks_secret.ptr != nullptr && 
api_key_jceks_secret.len != 0) {
     string api_key;
     string api_key_secret(
         reinterpret_cast<char*>(api_key_jceks_secret.ptr), 
api_key_jceks_secret.len);
-    Status status = ExecEnv::GetInstance()->frontend()->GetSecretFromKeyStore(
-        api_key_secret, &api_key);
-    if (!status.ok()) {
-      return StringVal::CopyFrom(ctx,
-          reinterpret_cast<const uint8_t*>(status.msg().msg().c_str()),
-          status.msg().msg().length());
-    }
-    headers.emplace_back("Authorization: Bearer " + api_key);
+    RETURN_STRINGVAL_IF_ERROR(ctx,
+        ExecEnv::GetInstance()->frontend()->GetSecretFromKeyStore(
+            api_key_secret, &api_key));
+    RETURN_STRINGVAL_IF_ERROR(ctx,
+        getAuthorizationHeader<platform>(authHeader, api_key));
   } else {
-    headers.emplace_back("Authorization: Bearer " + ai_api_key_);
+    RETURN_STRINGVAL_IF_ERROR(ctx,
+        getAuthorizationHeader<platform>(authHeader, ai_api_key_));
   }
+  headers.emplace_back(authHeader);
   // Generate the payload for the POST request
   Document payload;
   payload.SetObject();
   Document::AllocatorType& payload_allocator = payload.GetAllocator();
-  if (!fastpath && model.ptr != nullptr && model.len != 0) {
-    payload.AddMember("model",
-        rapidjson::StringRef(reinterpret_cast<char*>(model.ptr), model.len),
-        payload_allocator);
-  } else {
-    payload.AddMember("model",
-        rapidjson::StringRef(FLAGS_ai_model.c_str(), FLAGS_ai_model.length()),
-        payload_allocator);
+  // Azure Open AI endpoint doesn't expect model as a separate param since it's
+  // embedded in the endpoint. The 'deployment_name' below maps to a model.
+  // 
https://<resource_name>.openai.azure.com/openai/deployments/<deployment_name>/..
+  if (platform != AI_PLATFORM::AZURE_OPEN_AI) {
+    if (!fastpath && model.ptr != nullptr && model.len != 0) {
+      payload.AddMember("model",
+          rapidjson::StringRef(reinterpret_cast<char*>(model.ptr), model.len),
+          payload_allocator);
+    } else {
+      payload.AddMember("model",
+          rapidjson::StringRef(FLAGS_ai_model.c_str(), 
FLAGS_ai_model.length()),
+          payload_allocator);
+    }
   }
   Value message_array(rapidjson::kArrayType);
   Value message(rapidjson::kObjectType);
@@ -169,11 +183,7 @@ StringVal 
AiFunctions::AiGenerateTextInternal(FunctionContext* ctx,
     status = curl.PostToURL(endpoint_str, payload_str, &resp, headers);
   }
   VLOG(2) << "AI Generate Text: \noriginal response: " << resp.ToString();
-  if (!status.ok()) {
-    string msg = status.ToString();
-    return StringVal::CopyFrom(
-        ctx, reinterpret_cast<const uint8_t*>(msg.c_str()), msg.size());
-  }
+  if (UNLIKELY(!status.ok())) return copyErrorMessage(ctx, status.ToString());
   // Parse the JSON response string
   std::string response = AiGenerateTextParseOpenAiResponse(
       std::string_view(reinterpret_cast<char*>(resp.data()), resp.size()));
diff --git a/be/src/exprs/expr-test.cc b/be/src/exprs/expr-test.cc
index 61e53933f..e57dbeeb7 100644
--- a/be/src/exprs/expr-test.cc
+++ b/be/src/exprs/expr-test.cc
@@ -11227,7 +11227,10 @@ TEST_P(ExprTest, AiFunctionsTest) {
   string secret_key("do_not_share");
   AiFunctions::set_api_key(secret_key);
   // valid endpoint
-  StringVal openai_endpoint("https://openai.azure.com");
+  std::string_view 
openai_endpoint("https://api.openai.com/v1/chat/completions");
+  std::string_view azure_openai_endpoint(
+      "https://resource.openai.azure.com/openai/deployments/"
+      "deployment/completions?api-version=2024-02-01");
   // empty jceks secret key
   StringVal jceks_secret("");
   // dummy model.
@@ -11239,9 +11242,19 @@ TEST_P(ExprTest, AiFunctionsTest) {
   // dry_run to receive HTTP request header and body
   bool dry_run = true;
 
+  // Test GetAiPlatformFromEndpoint
+  EXPECT_EQ(AiFunctions::AI_PLATFORM::OPEN_AI,
+      AiFunctions::GetAiPlatformFromEndpoint(openai_endpoint));
+  EXPECT_EQ(AiFunctions::AI_PLATFORM::AZURE_OPEN_AI,
+      AiFunctions::GetAiPlatformFromEndpoint(azure_openai_endpoint));
+  EXPECT_EQ(AiFunctions::AI_PLATFORM::UNSUPPORTED,
+      AiFunctions::GetAiPlatformFromEndpoint("https://qwerty.com"));
+
   // Test fastpath
-  StringVal result = AiFunctions::AiGenerateTextInternal<true>(ctx, 
StringVal::null(),
-      prompt, StringVal::null(), StringVal::null(), StringVal::null(), 
dry_run);
+  StringVal result =
+    AiFunctions::AiGenerateTextInternal<true, 
AiFunctions::AI_PLATFORM::OPEN_AI>(
+        ctx, FLAGS_ai_endpoint, prompt, StringVal::null(), StringVal::null(),
+        StringVal::null(), dry_run);
   EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
       string("https://api.openai.com/v1/chat/completions"
              "\nContent-Type: application/json"
@@ -11249,22 +11262,34 @@ TEST_P(ExprTest, AiFunctionsTest) {
              
"\n{\"model\":\"gpt-4\",\"messages\":[{\"role\":\"user\",\"content\":"
              "\"hello!\"}]}"));
 
+  result =
+    AiFunctions::AiGenerateTextInternal<true, 
AiFunctions::AI_PLATFORM::AZURE_OPEN_AI>(
+        ctx, azure_openai_endpoint, prompt, StringVal::null(), 
StringVal::null(),
+        StringVal::null(), dry_run);
+  EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
+      string("https://resource.openai.azure.com/openai/deployments/"
+             "deployment/completions?api-version=2024-02-01"
+             "\nContent-Type: application/json"
+             "\napi-key: do_not_share"
+             "\n{\"messages\":[{\"role\":\"user\",\"content\":"
+             "\"hello!\"}]}"));
+
   // Test endpoints.
   // endpoints must begin with https.
-  result = AiFunctions::AiGenerateTextInternal<false>(
-      ctx, StringVal("http://ai.com"), prompt, model, jceks_secret, 
json_params, dry_run);
+  result = AiFunctions::AiGenerateText(
+      ctx, StringVal("http://ai.com"), prompt, model, jceks_secret, 
json_params);
   EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
       AiFunctions::AI_GENERATE_TXT_INVALID_PROTOCOL_ERROR);
   // only OpenAI endpoints are supported.
-  result = AiFunctions::AiGenerateTextInternal<false>(ctx, 
StringVal("https://ai.com"),
-      prompt, model, jceks_secret, json_params, dry_run);
+  result = AiFunctions::AiGenerateText(
+      ctx, "https://ai.com", prompt, model, jceks_secret, json_params);
   EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
       AiFunctions::AI_GENERATE_TXT_UNSUPPORTED_ENDPOINT_ERROR);
   // valid request using OpenAI endpoint.
-  result = AiFunctions::AiGenerateTextInternal<false>(
+  result = AiFunctions::AiGenerateTextInternal<false, 
AiFunctions::AI_PLATFORM::OPEN_AI>(
       ctx, openai_endpoint, prompt, model, jceks_secret, json_params, dry_run);
   EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
-      string("https://openai.azure.com"
+      string("https://api.openai.com/v1/chat/completions"
              "\nContent-Type: application/json"
              "\nAuthorization: Bearer do_not_share"
              
"\n{\"model\":\"bot\",\"messages\":[{\"role\":\"user\",\"content\":\"hello!"
@@ -11273,21 +11298,27 @@ TEST_P(ExprTest, AiFunctionsTest) {
   // Test prompt.
   // prompt cannot be empty.
   StringVal invalid_prompt("");
-  result = AiFunctions::AiGenerateTextInternal<false>(
+  result = AiFunctions::AiGenerateTextInternal<false, 
AiFunctions::AI_PLATFORM::OPEN_AI>(
       ctx, openai_endpoint, invalid_prompt, model, jceks_secret, json_params, 
dry_run);
+  EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
+      AiFunctions::AI_GENERATE_TXT_INVALID_PROMPT_ERROR);
+  result = AiFunctions::AiGenerateTextDefault(ctx, invalid_prompt);
   EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
       AiFunctions::AI_GENERATE_TXT_INVALID_PROMPT_ERROR);
   // prompt cannot be null.
   invalid_prompt = StringVal::null();
-  result = AiFunctions::AiGenerateTextInternal<false>(
+  result = AiFunctions::AiGenerateTextInternal<false, 
AiFunctions::AI_PLATFORM::OPEN_AI>(
       ctx, openai_endpoint, invalid_prompt, model, jceks_secret, json_params, 
dry_run);
   EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
       AiFunctions::AI_GENERATE_TXT_INVALID_PROMPT_ERROR);
+  result = AiFunctions::AiGenerateTextDefault(ctx, invalid_prompt);
+  EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
+      AiFunctions::AI_GENERATE_TXT_INVALID_PROMPT_ERROR);
 
   // Test override/additional params
   // invalid json results in error.
   StringVal invalid_json_params("{\"temperature\": 0.49, \"stop\": 
[\"*\",::,]}");
-  result = AiFunctions::AiGenerateTextInternal<false>(
+  result = AiFunctions::AiGenerateTextInternal<false, 
AiFunctions::AI_PLATFORM::OPEN_AI>(
       ctx, openai_endpoint, prompt, model, jceks_secret, invalid_json_params, 
dry_run);
   EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
       AiFunctions::AI_GENERATE_TXT_JSON_PARSE_ERROR);
@@ -11295,10 +11326,10 @@ TEST_P(ExprTest, AiFunctionsTest) {
   // like 'temperature' and 'stop'.
   StringVal valid_json_params(
       "{\"model\": \"gpt\", \"temperature\": 0.49, \"stop\": [\"*\", \"%\"]}");
-  result = AiFunctions::AiGenerateTextInternal<false>(
+  result = AiFunctions::AiGenerateTextInternal<false, 
AiFunctions::AI_PLATFORM::OPEN_AI>(
       ctx, openai_endpoint, prompt, model, jceks_secret, valid_json_params, 
dry_run);
   EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
-      string("https://openai.azure.com"
+      string("https://api.openai.com/v1/chat/completions"
              "\nContent-Type: application/json"
              "\nAuthorization: Bearer do_not_share"
              
"\n{\"model\":\"gpt\",\"messages\":[{\"role\":\"user\",\"content\":\"hello!"
@@ -11306,44 +11337,45 @@ TEST_P(ExprTest, AiFunctionsTest) {
   // messages cannot be overriden, as they we constructed from the prompt.
   StringVal forbidden_msg_override(
       "{\"messages\": [{\"role\":\"system\",\"content\":\"howdy!\"}]}");
-  result = AiFunctions::AiGenerateTextInternal<false>(
+  result = AiFunctions::AiGenerateTextInternal<false, 
AiFunctions::AI_PLATFORM::OPEN_AI>(
       ctx, openai_endpoint, prompt, model, jceks_secret, 
forbidden_msg_override, dry_run);
   EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
       AiFunctions::AI_GENERATE_TXT_MSG_OVERRIDE_FORBIDDEN_ERROR);
   // 'n != 1' cannot be overriden as additional params
   StringVal forbidden_n_value("{\"n\": 2}");
-  result = AiFunctions::AiGenerateTextInternal<false>(
+  result = AiFunctions::AiGenerateTextInternal<false, 
AiFunctions::AI_PLATFORM::OPEN_AI>(
       ctx, openai_endpoint, prompt, model, jceks_secret, forbidden_n_value, 
dry_run);
   EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
       AiFunctions::AI_GENERATE_TXT_N_OVERRIDE_FORBIDDEN_ERROR);
   // non integer value of 'n' cannot be overriden as additional params
   StringVal forbidden_n_type("{\"n\": \"1\"}");
-  result = AiFunctions::AiGenerateTextInternal<false>(
+  result = AiFunctions::AiGenerateTextInternal<false, 
AiFunctions::AI_PLATFORM::OPEN_AI>(
       ctx, openai_endpoint, prompt, model, jceks_secret, forbidden_n_type, 
dry_run);
   EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
       AiFunctions::AI_GENERATE_TXT_N_OVERRIDE_FORBIDDEN_ERROR);
   // accept 'n=1' override as additional params
   StringVal allowed_n_override("{\"n\": 1}");
-  result = AiFunctions::AiGenerateTextInternal<false>(
+  result = AiFunctions::AiGenerateTextInternal<false, 
AiFunctions::AI_PLATFORM::OPEN_AI>(
       ctx, openai_endpoint, prompt, model, jceks_secret, allowed_n_override, 
dry_run);
   EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
-      string("https://openai.azure.com"
+      string("https://api.openai.com/v1/chat/completions"
              "\nContent-Type: application/json"
              "\nAuthorization: Bearer do_not_share"
              
"\n{\"model\":\"bot\",\"messages\":[{\"role\":\"user\",\"content\":\"hello!"
              "\"}],\"n\":1}"));
 
   // Test flag file options are used when input is empty/null
-  result = AiFunctions::AiGenerateTextInternal<false>(ctx, StringVal::null(), 
prompt,
-      StringVal::null(), jceks_secret, json_params, dry_run);
+  result = AiFunctions::AiGenerateTextInternal<false, 
AiFunctions::AI_PLATFORM::OPEN_AI>(
+      ctx, FLAGS_ai_endpoint, prompt, StringVal::null(), jceks_secret, 
json_params,
+      dry_run);
   EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
       string("https://api.openai.com/v1/chat/completions"
              "\nContent-Type: application/json"
              "\nAuthorization: Bearer do_not_share"
              
"\n{\"model\":\"gpt-4\",\"messages\":[{\"role\":\"user\",\"content\":"
              "\"hello!\"}]}"));
-  result = AiFunctions::AiGenerateTextInternal<false>(
-      ctx, StringVal(""), prompt, StringVal(""), jceks_secret, json_params, 
dry_run);
+  result = AiFunctions::AiGenerateTextInternal<false, 
AiFunctions::AI_PLATFORM::OPEN_AI>(
+      ctx, FLAGS_ai_endpoint, prompt, StringVal(""), jceks_secret, 
json_params, dry_run);
   EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
       string("https://api.openai.com/v1/chat/completions"
              "\nContent-Type: application/json"

Reply via email to