This is an automated email from the ASF dual-hosted git repository.
arawat pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/impala.git
The following commit(s) were added to refs/heads/master by this push:
new 9837637d9 IMPALA-12920: Support ai_generate_text built-in function for
OpenAI's chat completion API
9837637d9 is described below
commit 9837637d9342a49288a13a421d4e749818da1432
Author: Yida Wu <[email protected]>
AuthorDate: Wed Sep 13 10:27:29 2023 -0700
IMPALA-12920: Support ai_generate_text built-in function for OpenAI's chat
completion API
Added support for following built-in functions:
- ai_generate_text_default(prompt)
- ai_generate_text(ai_endpoint, prompt, ai_model,
ai_api_key_jceks_secret, additional_params)
'ai_endpoint', 'ai_model' and 'ai_api_key_jceks_secret' are flagfile
options. 'ai_generate_text_default(prompt)' syntax expects all these
to be set to proper values. The other syntax will try to use the
provided input parameter values, but fall back to instance-level values
if the inputs are NULL or empty.
Only public OpenAI (api.openai.com) and Azure OpenAI (openai.azure.com)
API endpoints are currently supported.
Exposed these functions in FunctionContext so that they can also be
called from UDFs:
- ai_generate_text_default(context, prompt)
- ai_generate_text(context, ai_endpoint, prompt, ai_model,
ai_api_key_jceks_secret, additional_params)
Testing:
- Added unit tests for AiGenerateTextInternal function
- Added fe test for JniFrontend::getSecretFromKeyStore
- Ran manual tests to make sure Impala can talk with OpenAI LLMs using
'ai_generate_text' built-in function. Example sql:
select ai_generate_text("https://api.openai.com/v1/chat/completions",
"hello", "gpt-3.5-turbo", "open-ai-key",
'{"temperature": 0.9, "model": "gpt-4"}')
- Tested using standalone UDF SDK and made sure that the UDFs can invoke
BuiltInFunctions (ai_generate_text and ai_generate_text_default)
Change-Id: Id4446957f6030bab1f985fdd69185c3da07d7c4b
Reviewed-on: http://gerrit.cloudera.org:8080/21168
Reviewed-by: Impala Public Jenkins <[email protected]>
Tested-by: Impala Public Jenkins <[email protected]>
---
be/src/exprs/CMakeLists.txt | 1 +
be/src/exprs/ai-functions-ir.cc | 140 ++++++++++++++++
be/src/exprs/ai-functions.h | 73 ++++++++
be/src/exprs/ai-functions.inline.h | 184 +++++++++++++++++++++
be/src/exprs/expr-test.cc | 182 ++++++++++++++++++++
be/src/exprs/scalar-expr-evaluator.cc | 2 +
be/src/runtime/exec-env.cc | 13 ++
be/src/service/frontend.cc | 22 ++-
be/src/service/frontend.h | 5 +
be/src/udf/udf-internal.h | 5 +-
be/src/udf/udf.cc | 26 ++-
be/src/udf/udf.h | 17 ++
be/src/udf_samples/udf-sample.cc | 28 ++++
be/src/udf_samples/udf-sample.h | 3 +
be/src/util/jni-util.h | 10 ++
bin/load-data.py | 15 ++
bin/rat_exclude_files.txt | 1 +
common/function-registry/impala_functions.py | 4 +
.../org/apache/impala/service/JniFrontend.java | 28 ++++
.../org/apache/impala/service/JniFrontendTest.java | 25 +++
.../common/etc/hadoop/conf/core-site.xml.py | 6 +
testdata/jceks/.gitkeep | 0
22 files changed, 784 insertions(+), 6 deletions(-)
diff --git a/be/src/exprs/CMakeLists.txt b/be/src/exprs/CMakeLists.txt
index 782f652e7..df5ef65f3 100644
--- a/be/src/exprs/CMakeLists.txt
+++ b/be/src/exprs/CMakeLists.txt
@@ -25,6 +25,7 @@ set(EXECUTABLE_OUTPUT_PATH
"${BUILD_OUTPUT_ROOT_DIRECTORY}/exprs")
set(MURMURHASH_SRC_DIR "${CMAKE_SOURCE_DIR}/be/src/thirdparty/murmurhash")
add_library(ExprsIr
+ ai-functions-ir.cc
agg-fn-evaluator-ir.cc
aggregate-functions-ir.cc
bit-byte-functions-ir.cc
diff --git a/be/src/exprs/ai-functions-ir.cc b/be/src/exprs/ai-functions-ir.cc
new file mode 100644
index 000000000..e482cb688
--- /dev/null
+++ b/be/src/exprs/ai-functions-ir.cc
@@ -0,0 +1,140 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// The functions in this file are specifically not cross-compiled to IR
because there
+// is no significant performance benefit to be gained.
+
+#include <gutil/strings/util.h>
+
+#include "exprs/ai-functions.inline.h"
+
+using namespace impala_udf;
+
+DEFINE_string(ai_endpoint, "https://api.openai.com/v1/chat/completions",
+ "The default API endpoint for an external AI engine.");
+DEFINE_validator(ai_endpoint, [](const char* name, const string& endpoint) {
+ return (impala::AiFunctions::is_api_endpoint_valid(endpoint) &&
+ impala::AiFunctions::is_api_endpoint_supported(endpoint));
+});
+
+DEFINE_string(ai_model, "gpt-4", "The default AI model used by an external AI
engine.");
+
+DEFINE_string(ai_api_key_jceks_secret, "",
+ "The jceks secret key used for extracting the api key from configured
keystores. "
+ "'hadoop.security.credential.provider.path' in core-site must be
configured to "
+ "include the keystore storing the corresponding secret.");
+
+DEFINE_int32(ai_connection_timeout_s, 10,
+ "(Advanced) The time in seconds for connection timed out when
communicating with an "
+ "external AI engine");
+TAG_FLAG(ai_api_key_jceks_secret, sensitive);
+
+namespace impala {
+
+// static class members
+const string AiFunctions::AI_GENERATE_TXT_JSON_PARSE_ERROR = "Invalid Json";
+const string AiFunctions::AI_GENERATE_TXT_INVALID_PROTOCOL_ERROR =
+ "Invalid Protocol, use https";
+const string AiFunctions::AI_GENERATE_TXT_UNSUPPORTED_ENDPOINT_ERROR =
+ "Unsupported Endpoint";
+const string AiFunctions::AI_GENERATE_TXT_INVALID_PROMPT_ERROR =
+ "Invalid Prompt, cannot be null or empty";
+const string AiFunctions::AI_GENERATE_TXT_MSG_OVERRIDE_FORBIDDEN_ERROR =
+ "Invalid override, 'messages' cannot be overriden";
+const string AiFunctions::AI_GENERATE_TXT_N_OVERRIDE_FORBIDDEN_ERROR =
+ "Invalid override, 'n' must be of integer type and have value 1";
+string AiFunctions::ai_api_key_;
+const char* AiFunctions::OPEN_AI_REQUEST_FIELD_CONTENT_TYPE_HEADER =
+ "Content-Type: application/json";
+
+// other constants
+static const StringVal NULL_STRINGVAL = StringVal::null();
+static const char* AI_API_ENDPOINT_PREFIX = "https://";
+static const char* OPEN_AI_AZURE_ENDPOINT = "openai.azure.com";
+static const char* OPEN_AI_PUBLIC_ENDPOINT = "api.openai.com";
+// OPEN AI specific constants
+static const char* OPEN_AI_RESPONSE_FIELD_CHOICES = "choices";
+static const char* OPEN_AI_RESPONSE_FIELD_MESSAGE = "message";
+static const char* OPEN_AI_RESPONSE_FIELD_CONTENT = "content";
+
+bool AiFunctions::is_api_endpoint_valid(const std::string_view& endpoint) {
+ // Simple validation for endpoint. It should start with https://
+ return (strncaseprefix(endpoint.data(), endpoint.size(),
AI_API_ENDPOINT_PREFIX,
+ sizeof(AI_API_ENDPOINT_PREFIX))
+ != nullptr);
+}
+
+bool AiFunctions::is_api_endpoint_supported(const std::string_view& endpoint) {
+ // Only OpenAI endpoints are supported.
+ return (
+ gstrncasestr(endpoint.data(), OPEN_AI_AZURE_ENDPOINT, endpoint.size())
!= nullptr ||
+ gstrncasestr(endpoint.data(), OPEN_AI_PUBLIC_ENDPOINT, endpoint.size())
!= nullptr);
+}
+
+std::string_view AiFunctions::AiGenerateTextParseOpenAiResponse(
+ const std::string_view& response) {
+ rapidjson::Document document;
+ document.Parse(response.data(), response.size());
+ // Check for parse errors
+ if (document.HasParseError()) {
+ LOG(WARNING) << AI_GENERATE_TXT_JSON_PARSE_ERROR << ": " << response;
+ return AI_GENERATE_TXT_JSON_PARSE_ERROR;
+ }
+ // Check if the "choices" array exists and is not empty
+ if (!document.HasMember(OPEN_AI_RESPONSE_FIELD_CHOICES)
+ || !document[OPEN_AI_RESPONSE_FIELD_CHOICES].IsArray()
+ || document[OPEN_AI_RESPONSE_FIELD_CHOICES].Empty()) {
+ LOG(WARNING) << AI_GENERATE_TXT_JSON_PARSE_ERROR << ": " << response;
+ return AI_GENERATE_TXT_JSON_PARSE_ERROR;
+ }
+
+ // Access the first element of the "choices" array
+ const rapidjson::Value& firstChoice =
document[OPEN_AI_RESPONSE_FIELD_CHOICES][0];
+
+ // Check if the "message" object exists
+ if (!firstChoice.HasMember(OPEN_AI_RESPONSE_FIELD_MESSAGE)
+ || !firstChoice[OPEN_AI_RESPONSE_FIELD_MESSAGE].IsObject()) {
+ LOG(WARNING) << AI_GENERATE_TXT_JSON_PARSE_ERROR << ": " << response;
+ return AI_GENERATE_TXT_JSON_PARSE_ERROR;
+ }
+
+ // Access the "content" field within "message"
+ const rapidjson::Value& message =
firstChoice[OPEN_AI_RESPONSE_FIELD_MESSAGE];
+ if (!message.HasMember(OPEN_AI_RESPONSE_FIELD_CONTENT)
+ || !message[OPEN_AI_RESPONSE_FIELD_CONTENT].IsString()) {
+ LOG(WARNING) << AI_GENERATE_TXT_JSON_PARSE_ERROR << ": " << response;
+ return AI_GENERATE_TXT_JSON_PARSE_ERROR;
+ }
+
+ const rapidjson::Value& result = message[OPEN_AI_RESPONSE_FIELD_CONTENT];
+ return std::string_view(result.GetString(), result.GetStringLength());
+}
+
+StringVal AiFunctions::AiGenerateText(FunctionContext* ctx, const StringVal&
endpoint,
+ const StringVal& prompt, const StringVal& model,
+ const StringVal& api_key_jceks_secret, const StringVal& params) {
+ return AiGenerateTextInternal<false>(
+ ctx, endpoint, prompt, model, api_key_jceks_secret, params, false);
+}
+
+StringVal AiFunctions::AiGenerateTextDefault(
+ FunctionContext* ctx, const StringVal& prompt) {
+ return AiGenerateTextInternal<true>(
+ ctx, NULL_STRINGVAL, prompt, NULL_STRINGVAL, NULL_STRINGVAL,
NULL_STRINGVAL, false);
+}
+
+} // namespace impala
diff --git a/be/src/exprs/ai-functions.h b/be/src/exprs/ai-functions.h
new file mode 100644
index 000000000..c1d2e635e
--- /dev/null
+++ b/be/src/exprs/ai-functions.h
@@ -0,0 +1,73 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <string_view>
+
+#include "udf/udf.h"
+
+using namespace impala_udf;
+
+namespace impala {
+
+using impala_udf::FunctionContext;
+using impala_udf::StringVal;
+
+class AiFunctions {
+ public:
+ static const string AI_GENERATE_TXT_JSON_PARSE_ERROR;
+ static const string AI_GENERATE_TXT_INVALID_PROTOCOL_ERROR;
+ static const string AI_GENERATE_TXT_UNSUPPORTED_ENDPOINT_ERROR;
+ static const string AI_GENERATE_TXT_INVALID_PROMPT_ERROR;
+ static const string AI_GENERATE_TXT_MSG_OVERRIDE_FORBIDDEN_ERROR;
+ static const string AI_GENERATE_TXT_N_OVERRIDE_FORBIDDEN_ERROR;
+ static const char* OPEN_AI_REQUEST_FIELD_CONTENT_TYPE_HEADER;
+ /// Sends a prompt to the input AI endpoint using the input model, api_key
and
+ /// optional params.
+ static StringVal AiGenerateText(FunctionContext* ctx, const StringVal&
endpoint,
+ const StringVal& prompt, const StringVal& model,
+ const StringVal& api_key_jceks_secret, const StringVal& params);
+ /// Sends a prompt to the default endpoint and uses the default model,
default
+ /// api-key and default params.
+ static StringVal AiGenerateTextDefault(FunctionContext* ctx, const
StringVal& prompt);
+ /// Set the ai_api_key_ member.
+ static void set_api_key(string& api_key) { ai_api_key_ = api_key; }
+ /// Validate API endpoint.
+ static bool is_api_endpoint_valid(const std::string_view& endpoint);
+ /// Check if endpoint is supported
+ static bool is_api_endpoint_supported(const std::string_view& endpoint);
+
+ private:
+ /// The default api_key used for communicating with external APIs.
+ static std::string ai_api_key_;
+ /// Internal function which implements the logic of parsing user input and
sending
+ /// request to the external API endpoint. If 'dry_run' is set, the POST
request is
+ /// returned. 'dry_run' mode is used only for unit tests.
+ template <bool fastpath>
+ static StringVal AiGenerateTextInternal(FunctionContext* ctx, const
StringVal& endpoint,
+ const StringVal& prompt, const StringVal& model,
+ const StringVal& api_key_jceks_secret, const StringVal& params, const
bool dry_run);
+ /// Internal helper function for parsing OPEN AI's API response. Input
parameter is the
+ /// json representation of the OPEN AI's API response.
+ static std::string_view AiGenerateTextParseOpenAiResponse(
+ const std::string_view& reponse);
+
+ friend class ExprTest_AiFunctionsTest_Test;
+};
+
+} // namespace impala
diff --git a/be/src/exprs/ai-functions.inline.h
b/be/src/exprs/ai-functions.inline.h
new file mode 100644
index 000000000..7f7bcfd92
--- /dev/null
+++ b/be/src/exprs/ai-functions.inline.h
@@ -0,0 +1,184 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <rapidjson/document.h>
+#include <rapidjson/error/en.h>
+#include <rapidjson/stringbuffer.h>
+#include <rapidjson/writer.h>
+
+#include "common/compiler-util.h"
+#include "exprs/ai-functions.h"
+#include "kudu/util/curl_util.h"
+#include "kudu/util/faststring.h"
+#include "kudu/util/flag_tags.h"
+#include "kudu/util/monotime.h"
+#include "kudu/util/status.h"
+#include "runtime/exec-env.h"
+#include "service/frontend.h"
+
+using namespace rapidjson;
+using namespace impala_udf;
+
+DECLARE_string(ai_endpoint);
+DECLARE_string(ai_model);
+DECLARE_string(ai_api_key_jceks_secret);
+DECLARE_int32(ai_connection_timeout_s);
+
+namespace impala {
+
+template <bool fastpath>
+StringVal AiFunctions::AiGenerateTextInternal(FunctionContext* ctx,
+ const StringVal& endpoint, const StringVal& prompt, const StringVal& model,
+ const StringVal& api_key_jceks_secret, const StringVal& params, const bool
dry_run) {
+ std::string_view endpoint_sv(FLAGS_ai_endpoint);
+ // endpoint validation
+ if (!fastpath && endpoint.ptr != nullptr && endpoint.len != 0) {
+ endpoint_sv = std::string_view(reinterpret_cast<char*>(endpoint.ptr),
endpoint.len);
+ // Simple validation for endpoint. It should start with https://
+ if (!is_api_endpoint_valid(endpoint_sv)) {
+ LOG(ERROR) << "AI Generate Text: \ninvalid protocol: " << endpoint_sv;
+ return StringVal(AI_GENERATE_TXT_INVALID_PROTOCOL_ERROR.c_str());
+ }
+ // Only OpenAI endpoints are supported.
+ if (!is_api_endpoint_supported(endpoint_sv)) {
+ LOG(ERROR) << "AI Generate Text: \nunsupported endpoint: " <<
endpoint_sv;
+ return StringVal(AI_GENERATE_TXT_UNSUPPORTED_ENDPOINT_ERROR.c_str());
+ }
+ }
+ // Generate the header for the POST request
+ vector<string> headers;
+ headers.emplace_back(OPEN_AI_REQUEST_FIELD_CONTENT_TYPE_HEADER);
+ if (!fastpath && api_key_jceks_secret.ptr != nullptr &&
api_key_jceks_secret.len != 0) {
+ string api_key;
+ string api_key_secret(
+ reinterpret_cast<char*>(api_key_jceks_secret.ptr),
api_key_jceks_secret.len);
+ Status status = ExecEnv::GetInstance()->frontend()->GetSecretFromKeyStore(
+ api_key_secret, &api_key);
+ if (!status.ok()) {
+ return StringVal::CopyFrom(ctx,
+ reinterpret_cast<const uint8_t*>(status.msg().msg().c_str()),
+ status.msg().msg().length());
+ }
+ headers.emplace_back("Authorization: Bearer " + api_key);
+ } else {
+ headers.emplace_back("Authorization: Bearer " + ai_api_key_);
+ }
+ // Generate the payload for the POST request
+ Document payload;
+ payload.SetObject();
+ Document::AllocatorType& payload_allocator = payload.GetAllocator();
+ if (!fastpath && model.ptr != nullptr && model.len != 0) {
+ payload.AddMember("model",
+ rapidjson::StringRef(reinterpret_cast<char*>(model.ptr), model.len),
+ payload_allocator);
+ } else {
+ payload.AddMember("model",
+ rapidjson::StringRef(FLAGS_ai_model.c_str(), FLAGS_ai_model.length()),
+ payload_allocator);
+ }
+ Value message_array(rapidjson::kArrayType);
+ Value message(rapidjson::kObjectType);
+ message.AddMember("role", "user", payload_allocator);
+ if (prompt.ptr == nullptr || prompt.len == 0) {
+ return StringVal(AI_GENERATE_TXT_INVALID_PROMPT_ERROR.c_str());
+ }
+ message.AddMember("content",
+ rapidjson::StringRef(reinterpret_cast<char*>(prompt.ptr), prompt.len),
+ payload_allocator);
+ message_array.PushBack(message, payload_allocator);
+ payload.AddMember("messages", message_array, payload_allocator);
+ // Override additional params
+ if (!fastpath && params.ptr != nullptr && params.len != 0) {
+ Document overrides;
+ overrides.Parse(reinterpret_cast<char*>(params.ptr), params.len);
+ if (overrides.HasParseError()) {
+ LOG(WARNING) << AI_GENERATE_TXT_JSON_PARSE_ERROR << ": error code "
+ << overrides.GetParseError() << ", offset input "
+ << overrides.GetErrorOffset();
+ return StringVal(AI_GENERATE_TXT_JSON_PARSE_ERROR.c_str());
+ }
+ for (auto& m : overrides.GetObject()) {
+ if (payload.HasMember(m.name.GetString())) {
+ if (m.name == "messages") {
+ LOG(WARNING)
+ << AI_GENERATE_TXT_JSON_PARSE_ERROR
+ << ": 'messages' is constructed from 'prompt', cannot be
overridden";
+ return
StringVal(AI_GENERATE_TXT_MSG_OVERRIDE_FORBIDDEN_ERROR.c_str());
+ } else {
+ payload[m.name.GetString()] = m.value;
+ }
+ } else {
+ if ((m.name == "n") && !(m.value.IsInt() && m.value.GetInt() == 1)) {
+ LOG(WARNING)
+ << AI_GENERATE_TXT_JSON_PARSE_ERROR
+ << ": 'n' must be of integer type and have value 1";
+ return StringVal(AI_GENERATE_TXT_N_OVERRIDE_FORBIDDEN_ERROR.c_str());
+ }
+ payload.AddMember(m.name, m.value, payload_allocator);
+ }
+ }
+ }
+ // Convert payload into string for POST request
+ StringBuffer buffer;
+ Writer<StringBuffer> writer(buffer);
+ payload.Accept(writer);
+ string payload_str(buffer.GetString(), buffer.GetSize());
+ VLOG(2) << "AI Generate Text: \nendpoint: " << endpoint_sv
+ << " \npayload: " << payload_str;
+ if (UNLIKELY(dry_run)) {
+ std::stringstream post_request;
+ post_request << endpoint_sv;
+ for (auto& header : headers) {
+ post_request << "\n" << header;
+ }
+ post_request << "\n" << payload_str;
+ return StringVal::CopyFrom(ctx,
+ reinterpret_cast<const uint8_t*>(post_request.str().data()),
+ post_request.str().length());
+ }
+ // Send request to external AI API endpoint
+ kudu::EasyCurl curl;
+
curl.set_timeout(kudu::MonoDelta::FromSeconds(FLAGS_ai_connection_timeout_s));
+ curl.set_fail_on_http_error(true);
+ kudu::faststring resp;
+ kudu::Status status;
+ if (fastpath) {
+ DCHECK_EQ(std::string_view(FLAGS_ai_endpoint), endpoint_sv);
+ status = curl.PostToURL(FLAGS_ai_endpoint, payload_str, &resp, headers);
+ } else {
+ std::string endpoint_str{endpoint_sv};
+ status = curl.PostToURL(endpoint_str, payload_str, &resp, headers);
+ }
+ VLOG(2) << "AI Generate Text: \noriginal response: " << resp.ToString();
+ if (!status.ok()) {
+ string msg = status.ToString();
+ return StringVal::CopyFrom(
+ ctx, reinterpret_cast<const uint8_t*>(msg.c_str()), msg.size());
+ }
+ // Parse the JSON response string
+ std::string_view response = AiGenerateTextParseOpenAiResponse(
+ std::string_view(reinterpret_cast<char*>(resp.data()), resp.size()));
+ VLOG(2) << "AI Generate Text: \nresponse: " << response;
+ StringVal result(ctx, response.length());
+ if (UNLIKELY(result.is_null)) return StringVal::null();
+ memcpy(result.ptr, response.data(), response.length());
+ return result;
+}
+
+} // namespace impala
diff --git a/be/src/exprs/expr-test.cc b/be/src/exprs/expr-test.cc
index d77d02f9d..a05615057 100644
--- a/be/src/exprs/expr-test.cc
+++ b/be/src/exprs/expr-test.cc
@@ -34,6 +34,7 @@
#include "codegen/llvm-codegen.h"
#include "common/init.h"
#include "common/object-pool.h"
+#include "exprs/ai-functions.inline.h"
#include "exprs/anyval-util.h"
#include "exprs/is-null-predicate.h"
#include "exprs/like-predicate.h"
@@ -82,6 +83,8 @@
DECLARE_bool(abort_on_config_error);
DECLARE_bool(disable_optimization_passes);
DECLARE_string(hdfs_zone_info_zip);
+DECLARE_string(ai_endpoint);
+DECLARE_string(ai_model);
namespace posix_time = boost::posix_time;
using boost::bad_lexical_cast;
@@ -11167,6 +11170,185 @@ TEST_P(ExprTest, Utf8Test) {
executor_->PopExecOption();
}
+TEST_P(ExprTest, AiFunctionsTest) {
+ // Hack up a function context.
+ RuntimeState state(TQueryCtx(), ExecEnv::GetInstance());
+ MemTracker m;
+ MemPool pool(&m);
+ FunctionContext::TypeDesc str_desc;
+ str_desc.type = FunctionContext::Type::TYPE_STRING;
+ std::vector<FunctionContext::TypeDesc> v(3, str_desc);
+ FunctionContext* ctx = CreateUdfTestContext(str_desc, v, &state, &pool);
+ // dummy api key.
+ string secret_key("do_not_share");
+ AiFunctions::set_api_key(secret_key);
+ // valid endpoint
+ StringVal openai_endpoint("https://openai.azure.com");
+ // empty jceks secret key
+ StringVal jceks_secret("");
+ // dummy model.
+ StringVal model("bot");
+ // prompt message.
+ StringVal prompt("hello!");
+ // additional params
+ StringVal json_params;
+ // dry_run to receive HTTP request header and body
+ bool dry_run = true;
+
+ // Test fastpath
+ StringVal result = AiFunctions::AiGenerateTextInternal<true>(ctx,
StringVal::null(),
+ prompt, StringVal::null(), StringVal::null(), StringVal::null(),
dry_run);
+ EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
+ string("https://api.openai.com/v1/chat/completions"
+ "\nContent-Type: application/json"
+ "\nAuthorization: Bearer do_not_share"
+
"\n{\"model\":\"gpt-4\",\"messages\":[{\"role\":\"user\",\"content\":"
+ "\"hello!\"}]}"));
+
+ // Test endpoints.
+ // endpoints must begin with https.
+ result = AiFunctions::AiGenerateTextInternal<false>(
+ ctx, StringVal("http://ai.com"), prompt, model, jceks_secret,
json_params, dry_run);
+ EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
+ AiFunctions::AI_GENERATE_TXT_INVALID_PROTOCOL_ERROR);
+ // only OpenAI endpoints are supported.
+ result = AiFunctions::AiGenerateTextInternal<false>(ctx,
StringVal("https://ai.com"),
+ prompt, model, jceks_secret, json_params, dry_run);
+ EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
+ AiFunctions::AI_GENERATE_TXT_UNSUPPORTED_ENDPOINT_ERROR);
+ // valid request using OpenAI endpoint.
+ result = AiFunctions::AiGenerateTextInternal<false>(
+ ctx, openai_endpoint, prompt, model, jceks_secret, json_params, dry_run);
+ EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
+ string("https://openai.azure.com"
+ "\nContent-Type: application/json"
+ "\nAuthorization: Bearer do_not_share"
+
"\n{\"model\":\"bot\",\"messages\":[{\"role\":\"user\",\"content\":\"hello!"
+ "\"}]}"));
+
+ // Test prompt.
+ // prompt cannot be empty.
+ StringVal invalid_prompt("");
+ result = AiFunctions::AiGenerateTextInternal<false>(
+ ctx, openai_endpoint, invalid_prompt, model, jceks_secret, json_params,
dry_run);
+ EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
+ AiFunctions::AI_GENERATE_TXT_INVALID_PROMPT_ERROR);
+ // prompt cannot be null.
+ invalid_prompt = StringVal::null();
+ result = AiFunctions::AiGenerateTextInternal<false>(
+ ctx, openai_endpoint, invalid_prompt, model, jceks_secret, json_params,
dry_run);
+ EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
+ AiFunctions::AI_GENERATE_TXT_INVALID_PROMPT_ERROR);
+
+ // Test override/additional params
+ // invalid json results in error.
+ StringVal invalid_json_params("{\"temperature\": 0.49, \"stop\":
[\"*\",::,]}");
+ result = AiFunctions::AiGenerateTextInternal<false>(
+ ctx, openai_endpoint, prompt, model, jceks_secret, invalid_json_params,
dry_run);
+ EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
+ AiFunctions::AI_GENERATE_TXT_JSON_PARSE_ERROR);
+ // valid json results in overriding existing params ('model'), and adding
new params
+ // like 'temperature' and 'stop'.
+ StringVal valid_json_params(
+ "{\"model\": \"gpt\", \"temperature\": 0.49, \"stop\": [\"*\", \"%\"]}");
+ result = AiFunctions::AiGenerateTextInternal<false>(
+ ctx, openai_endpoint, prompt, model, jceks_secret, valid_json_params,
dry_run);
+ EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
+ string("https://openai.azure.com"
+ "\nContent-Type: application/json"
+ "\nAuthorization: Bearer do_not_share"
+
"\n{\"model\":\"gpt\",\"messages\":[{\"role\":\"user\",\"content\":\"hello!"
+ "\"}],\"temperature\":0.49,\"stop\":[\"*\",\"%\"]}"));
+ // messages cannot be overridden, as they were constructed from the prompt.
+ StringVal forbidden_msg_override(
+ "{\"messages\": [{\"role\":\"system\",\"content\":\"howdy!\"}]}");
+ result = AiFunctions::AiGenerateTextInternal<false>(
+ ctx, openai_endpoint, prompt, model, jceks_secret,
forbidden_msg_override, dry_run);
+ EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
+ AiFunctions::AI_GENERATE_TXT_MSG_OVERRIDE_FORBIDDEN_ERROR);
+ // 'n != 1' cannot be overridden as additional params
+ StringVal forbidden_n_value("{\"n\": 2}");
+ result = AiFunctions::AiGenerateTextInternal<false>(
+ ctx, openai_endpoint, prompt, model, jceks_secret, forbidden_n_value,
dry_run);
+ EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
+ AiFunctions::AI_GENERATE_TXT_N_OVERRIDE_FORBIDDEN_ERROR);
+ // non-integer value of 'n' cannot be overridden as additional params
+ StringVal forbidden_n_type("{\"n\": \"1\"}");
+ result = AiFunctions::AiGenerateTextInternal<false>(
+ ctx, openai_endpoint, prompt, model, jceks_secret, forbidden_n_type,
dry_run);
+ EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
+ AiFunctions::AI_GENERATE_TXT_N_OVERRIDE_FORBIDDEN_ERROR);
+ // accept 'n=1' override as additional params
+ StringVal allowed_n_override("{\"n\": 1}");
+ result = AiFunctions::AiGenerateTextInternal<false>(
+ ctx, openai_endpoint, prompt, model, jceks_secret, allowed_n_override,
dry_run);
+ EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
+ string("https://openai.azure.com"
+ "\nContent-Type: application/json"
+ "\nAuthorization: Bearer do_not_share"
+
"\n{\"model\":\"bot\",\"messages\":[{\"role\":\"user\",\"content\":\"hello!"
+ "\"}],\"n\":1}"));
+
+ // Test flag file options are used when input is empty/null
+ result = AiFunctions::AiGenerateTextInternal<false>(ctx, StringVal::null(),
prompt,
+ StringVal::null(), jceks_secret, json_params, dry_run);
+ EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
+ string("https://api.openai.com/v1/chat/completions"
+ "\nContent-Type: application/json"
+ "\nAuthorization: Bearer do_not_share"
+
"\n{\"model\":\"gpt-4\",\"messages\":[{\"role\":\"user\",\"content\":"
+ "\"hello!\"}]}"));
+ result = AiFunctions::AiGenerateTextInternal<false>(
+ ctx, StringVal(""), prompt, StringVal(""), jceks_secret, json_params,
dry_run);
+ EXPECT_EQ(string(reinterpret_cast<char*>(result.ptr), result.len),
+ string("https://api.openai.com/v1/chat/completions"
+ "\nContent-Type: application/json"
+ "\nAuthorization: Bearer do_not_share"
+
"\n{\"model\":\"gpt-4\",\"messages\":[{\"role\":\"user\",\"content\":"
+ "\"hello!\"}]}"));
+
+ // Test OPEN AI's API response parsing
+ string content(
+ "A null-terminated string is a character string in a programming "
+ "language like C and C++ that ends with a null character (\'\\\\0\') .
This "
+ "character represents the end of the string and is used to determine the
"
+ "conclusion of the text. Essentially, it is a sequence of characters "
+ "followed by a null byte.");
+ std::ostringstream response;
+ response << "{\"id\": \"chatcmpl-9CGu8eeg1WKbKXGaNrCyHE38mQX90\","
+ << "\"object\": \"chat.completion\","
+ << "\"created\": 1712711944,"
+ << "\"model\": \"gpt-4-0613\","
+ << "\"choices\": ["
+ << "{"
+ << "\"index\": 0,"
+ << "\"message\": {"
+ << "\"role\": \"assistant\","
+ << "\"content\": " << "\"" << content << "\""
+ << "},"
+ << "\"logprobs\": null,"
+ << "\"finish_reason\": \"stop\""
+ << "}"
+ << "],"
+ << "\"usage\": {"
+ << "\"prompt_tokens\": 13,"
+ << "\"completion_tokens\": 60,"
+ << "\"total_tokens\": 73"
+ << "},"
+ << "\"system_fingerprint\": null}";
+ std::string_view res =
AiFunctions::AiGenerateTextParseOpenAiResponse(response.str());
+ string from_null("(\'\\\\0\')");
+ string to_null("(\'\\0\')");
+ size_t pos = content.find(from_null);
+ content.replace(pos, from_null.length(), to_null);
+ EXPECT_EQ(string(res), content);
+
+ // resource cleanup
+ pool.FreeAll();
+ UdfTestHarness::CloseContext(ctx);
+ state.ReleaseResources();
+}
+
} // namespace impala
INSTANTIATE_TEST_CASE_P(Instantiations, ExprTest, ::testing::Values(
diff --git a/be/src/exprs/scalar-expr-evaluator.cc
b/be/src/exprs/scalar-expr-evaluator.cc
index 2d3ec0404..e4d9cb3fd 100644
--- a/be/src/exprs/scalar-expr-evaluator.cc
+++ b/be/src/exprs/scalar-expr-evaluator.cc
@@ -22,6 +22,7 @@
#include "common/object-pool.h"
#include "common/status.h"
#include "exprs/aggregate-functions.h"
+#include "exprs/ai-functions.h"
#include "exprs/anyval-util.h"
#include "exprs/bit-byte-functions.h"
#include "exprs/case-expr.h"
@@ -449,6 +450,7 @@ DateVal ScalarExprEvaluator::GetDateVal(const TupleRow*
row) {
void ScalarExprEvaluator::InitBuiltinsDummy() {
// Call one function from each of the classes to pull all the symbols
// from that class in.
+ AiFunctions::is_api_endpoint_supported("");
AggregateFunctions::InitNull(nullptr, nullptr);
BitByteFunctions::CountSet(nullptr, TinyIntVal::null());
CastFunctions::CastToBooleanVal(nullptr, TinyIntVal::null());
diff --git a/be/src/runtime/exec-env.cc b/be/src/runtime/exec-env.cc
index 506547413..61dff495c 100644
--- a/be/src/runtime/exec-env.cc
+++ b/be/src/runtime/exec-env.cc
@@ -28,6 +28,7 @@
#include "common/logging.h"
#include "common/object-pool.h"
#include "exec/kudu/kudu-util.h"
+#include "exprs/ai-functions.h"
#include "kudu/rpc/service_if.h"
#include "rpc/rpc-mgr.h"
#include "runtime/bufferpool/buffer-pool.h"
@@ -152,6 +153,8 @@ DECLARE_int32(state_store_2_port);
DECLARE_string(debug_actions);
DECLARE_string(ssl_client_ca_certificate);
+DECLARE_string(ai_api_key_jceks_secret);
+
DEFINE_int32(backend_client_connection_num_retries, 3, "Retry backend
connections.");
// When network is unstable, TCP will retry and sending could take longer time.
// Choose 5 minutes as default timeout because we don't want RPC timeout be
triggered
@@ -517,6 +520,16 @@ Status ExecEnv::Init() {
RETURN_IF_ERROR(admission_controller_->Init());
RETURN_IF_ERROR(InitHadoopConfig());
+
+ // If 'ai_api_key_jceks_secret' is set then extract the api_key and populate
+ // AiFunctions::ai_api_key_
+ if (frontend_ != nullptr && FLAGS_ai_api_key_jceks_secret != "") {
+ string api_key;
+ RETURN_IF_ERROR(
+ frontend_->GetSecretFromKeyStore(FLAGS_ai_api_key_jceks_secret,
&api_key));
+ AiFunctions::set_api_key(api_key);
+ }
+
return Status::OK();
}
diff --git a/be/src/service/frontend.cc b/be/src/service/frontend.cc
index 97d58ea77..6cc00a710 100644
--- a/be/src/service/frontend.cc
+++ b/be/src/service/frontend.cc
@@ -143,17 +143,26 @@ Frontend::Frontend() {
{"commitKuduTransaction", "([B)V", &commit_kudu_txn_}
};
+ JniMethodDescriptor staticMethods[] = {
+ {"getSecretFromKeyStore", "([B)Ljava/lang/String;",
&get_secret_from_key_store_}
+ };
+
JNIEnv* jni_env = JniUtil::GetJNIEnv();
JniLocalFrame jni_frame;
ABORT_IF_ERROR(jni_frame.push(jni_env));
// create instance of java class JniFrontend
- jclass fe_class = jni_env->FindClass(FLAGS_jni_frontend_class.c_str());
+ fe_class_ = jni_env->FindClass(FLAGS_jni_frontend_class.c_str());
ABORT_IF_EXC(jni_env);
uint32_t num_methods = sizeof(methods) / sizeof(methods[0]);
for (int i = 0; i < num_methods; ++i) {
- ABORT_IF_ERROR(JniUtil::LoadJniMethod(jni_env, fe_class, &(methods[i])));
+ ABORT_IF_ERROR(JniUtil::LoadJniMethod(jni_env, fe_class_, &(methods[i])));
+ };
+
+ num_methods = sizeof(staticMethods) / sizeof(staticMethods[0]);
+ for (int i = 0; i < num_methods; ++i) {
+ ABORT_IF_ERROR(JniUtil::LoadStaticJniMethod(jni_env, fe_class_,
&(staticMethods[i])));
};
jbyteArray cfg_bytes;
@@ -162,7 +171,7 @@ Frontend::Frontend() {
// Pass in whether this is a backend test, so that the Frontend can avoid
certain
// unnecessary initialization that introduces dependencies on a running
minicluster.
jboolean is_be_test = TestInfo::is_be_test();
- jobject fe = jni_env->NewObject(fe_class, fe_ctor_, cfg_bytes, is_be_test);
+ jobject fe = jni_env->NewObject(fe_class_, fe_ctor_, cfg_bytes, is_be_test);
ABORT_IF_EXC(jni_env);
ABORT_IF_ERROR(JniUtil::LocalToGlobalRef(jni_env, fe, &fe_));
}
@@ -407,3 +416,10 @@ Status Frontend::CommitKuduTransaction(const TUniqueId&
query_id) {
Status Frontend::Convert(const TExecRequest& request) {
return JniUtil::CallJniMethod(fe_, convertTable, request);
}
+
+Status Frontend::GetSecretFromKeyStore(const string& secret_key, string*
secret) {
+ TStringLiteral secret_key_t;
+ secret_key_t.__set_value(secret_key);
+ return JniUtil::CallStaticJniMethod(fe_class_, get_secret_from_key_store_,
secret_key_t,
+ secret);
+}
diff --git a/be/src/service/frontend.h b/be/src/service/frontend.h
index afdf0f808..823f98ca8 100644
--- a/be/src/service/frontend.h
+++ b/be/src/service/frontend.h
@@ -246,7 +246,11 @@ class Frontend {
/// Convert external Hdfs tables to Iceberg tables
Status Convert(const TExecRequest& request);
+ /// Get secret from jceks key store for the input secret_key.
+ Status GetSecretFromKeyStore(const string& secret_key, string* secret);
+
private:
+ jclass fe_class_; // org.apache.impala.service.JniFrontend class
jobject fe_; // instance of org.apache.impala.service.JniFrontend
jmethodID create_exec_request_id_; // JniFrontend.createExecRequest()
jmethodID get_explain_plan_id_; // JniFrontend.getExplainPlan()
@@ -287,6 +291,7 @@ class Frontend {
jmethodID abort_kudu_txn_; // JniFrontend.abortKuduTransaction()
jmethodID commit_kudu_txn_; // JniFrontend.commitKuduTransaction()
jmethodID convertTable; // JniFrontend.convertTable
+ jmethodID get_secret_from_key_store_; // JniFrontend.getSecretFromKeyStore()
// Only used for testing.
jmethodID build_test_descriptor_table_id_; //
JniFrontend.buildTestDescriptorTable()
diff --git a/be/src/udf/udf-internal.h b/be/src/udf/udf-internal.h
index 423e38f6b..92c1e8b42 100644
--- a/be/src/udf/udf-internal.h
+++ b/be/src/udf/udf-internal.h
@@ -29,7 +29,7 @@
#include "udf/udf.h"
namespace impala {
-
+using impala_udf::BuiltInFunctions;
#define RETURN_IF_NULL(ctx, ptr) \
do { \
if (UNLIKELY(ptr == NULL)) { \
@@ -287,6 +287,9 @@ class FunctionContextImpl {
/// Indicates whether this context has been closed. Used for
verification/debugging.
bool closed_;
+
+ /// Built-in functions exposed to UDFs
+ BuiltInFunctions functions_;
};
}
diff --git a/be/src/udf/udf.cc b/be/src/udf/udf.cc
index 409b71994..dee3fc446 100644
--- a/be/src/udf/udf.cc
+++ b/be/src/udf/udf.cc
@@ -116,10 +116,26 @@ class RuntimeState {
const std::string user_string_ = "";
};
+// Dummy AiFunctions class for UDF SDK
+static const std::string AI_FUNCTIONS_DUMMY_RESPONSE = "dummy response";
+using impala_udf::StringVal;
+using impala_udf::FunctionContext;
+class AiFunctions {
+ public:
+ static StringVal AiGenerateText(FunctionContext* ctx, const StringVal&
endpoint,
+ const StringVal& prompt, const StringVal& model,
+ const StringVal& api_key_jceks_secret, const StringVal& params) {
+ return StringVal(AI_FUNCTIONS_DUMMY_RESPONSE.c_str());
+ }
+ static StringVal AiGenerateTextDefault(FunctionContext* ctx, const
StringVal& prompt) {
+ return StringVal(AI_FUNCTIONS_DUMMY_RESPONSE.c_str());
+ }
+};
}
#else
#include "common/atomic.h"
+#include "exprs/ai-functions.h"
#include "exprs/anyval-util.h"
#include "runtime/free-pool.h"
#include "runtime/mem-pool.h"
@@ -190,6 +206,9 @@ FunctionContext*
FunctionContextImpl::CreateContext(RuntimeState* state,
aligned_malloc(varargs_buffer_size, VARARGS_BUFFER_ALIGNMENT));
ctx->impl_->varargs_buffer_size_ = varargs_buffer_size;
ctx->impl_->debug_ = debug;
+ ctx->impl_->functions_.ai_generate_text =
impala::AiFunctions::AiGenerateText;
+ ctx->impl_->functions_.ai_generate_text_default =
+ impala::AiFunctions::AiGenerateTextDefault;
VLOG_ROW << "Created FunctionContext: " << ctx;
return ctx;
}
@@ -204,8 +223,7 @@ FunctionContext* FunctionContextImpl::Clone(
return new_context;
}
-FunctionContext::FunctionContext() : impl_(new FunctionContextImpl(this)) {
-}
+FunctionContext::FunctionContext() : impl_(new FunctionContextImpl(this)) {}
FunctionContext::~FunctionContext() {
assert(impl_->closed_ && "FunctionContext wasn't closed!");
@@ -475,6 +493,10 @@ void FunctionContext::SetFunctionState(FunctionStateScope
scope, void* ptr) {
}
}
+const BuiltInFunctions* FunctionContext::Functions() const {
+ return &impl_->functions_;
+}
+
uint8_t* FunctionContextImpl::AllocateForResults(int64_t byte_size) noexcept {
assert(!closed_);
#if !defined(NDEBUG) && !defined(IMPALA_UDF_SDK_BUILD)
diff --git a/be/src/udf/udf.h b/be/src/udf/udf.h
index 45d1b2295..2f34a44fd 100644
--- a/be/src/udf/udf.h
+++ b/be/src/udf/udf.h
@@ -61,6 +61,21 @@ struct BigIntVal;
struct StringVal;
struct TimestampVal;
struct DateVal;
+class FunctionContext;
+
+/// Built-in functions exposed to UDFs
+struct BuiltInFunctions {
+ public:
+ /// Sends a prompt to the default endpoint and uses the default model,
default
+ /// jceks api-key secret and default params.
+ StringVal (*ai_generate_text_default)(
+ FunctionContext* context, const StringVal& prompt);
+ /// Sends a prompt to the input AI endpoint using the input model, jceks
api_key secret
+ /// and optional params.
+ StringVal (*ai_generate_text)(FunctionContext* context, const StringVal&
endpoint,
+ const StringVal& prompt, const StringVal& model,
+ const StringVal& api_key_jceks_secret, const StringVal& params);
+};
/// A FunctionContext is passed to every UDF/UDA and is the interface for the
UDF to the
/// rest of the system. It contains APIs to examine the system state, report
errors and
@@ -246,6 +261,8 @@ class FunctionContext {
/// use this. This is used internally.
impala::FunctionContextImpl* impl() const { return impl_; }
+ const BuiltInFunctions* Functions() const;
+
~FunctionContext();
private:
diff --git a/be/src/udf_samples/udf-sample.cc b/be/src/udf_samples/udf-sample.cc
index 0aca16483..2c3d380ec 100644
--- a/be/src/udf_samples/udf-sample.cc
+++ b/be/src/udf_samples/udf-sample.cc
@@ -17,6 +17,8 @@
#include "udf-sample.h"
+#include <string>
+
// In this sample we are declaring a UDF that adds two ints and returns an int.
IMPALA_UDF_EXPORT
IntVal AddUdf(FunctionContext* context, const IntVal& arg1, const IntVal&
arg2) {
@@ -26,3 +28,29 @@ IntVal AddUdf(FunctionContext* context, const IntVal& arg1,
const IntVal& arg2)
// Multiple UDFs can be defined in the same file
+// Classify input customer reviews.
+IMPALA_UDF_EXPORT
+StringVal ClassifyReviewsDefault(FunctionContext* context, const StringVal&
input) {
+ std::string request =
+ std::string("Classify the following review as positive, neutral, or
negative")
+ + std::string(" and only include the uncapitalized category in the
response: ")
+ + std::string(reinterpret_cast<const char*>(input.ptr), input.len);
+ StringVal prompt(request.c_str());
+ return context->Functions()->ai_generate_text_default(context, prompt);
+}
+
+// Classify input customer reviews.
+IMPALA_UDF_EXPORT
+StringVal ClassifyReviews(FunctionContext* context, const StringVal& input) {
+ std::string request =
+ std::string("Classify the following review as positive, neutral, or
negative")
+ + std::string(" and only include the uncapitalized category in the
response: ")
+ + std::string(reinterpret_cast<const char*>(input.ptr), input.len);
+ StringVal prompt(request.c_str());
+ const StringVal endpoint("https://api.openai.com/v1/chat/completions");
+ const StringVal model("gpt-3.5-turbo");
+ const StringVal api_key_jceks_secret("open-ai-key");
+ const StringVal params("{\"temperature\": 0.9, \"model\": \"gpt-4\"}");
+ return context->Functions()->ai_generate_text(
+ context, endpoint, prompt, model, api_key_jceks_secret, params);
+}
diff --git a/be/src/udf_samples/udf-sample.h b/be/src/udf_samples/udf-sample.h
index b47b9e489..5ff8e7c5d 100644
--- a/be/src/udf_samples/udf-sample.h
+++ b/be/src/udf_samples/udf-sample.h
@@ -25,4 +25,7 @@ using namespace impala_udf;
IntVal AddUdf(FunctionContext* context, const IntVal& arg1, const IntVal&
arg2);
+StringVal ClassifyReviewsDefault(FunctionContext* context, const StringVal&
input);
+
+StringVal ClassifyReviews(FunctionContext* context, const StringVal& input);
#endif
diff --git a/be/src/util/jni-util.h b/be/src/util/jni-util.h
index 5005315bb..85dfc57b9 100644
--- a/be/src/util/jni-util.h
+++ b/be/src/util/jni-util.h
@@ -377,6 +377,10 @@ class JniUtil {
return JniCall::static_method(cls, method).Call();
}
+ template <typename T, typename R>
+ static Status CallStaticJniMethod(const jclass& cls, const jmethodID& method,
+ const T& arg, R* response) WARN_UNUSED_RESULT;
+
template <typename T>
static Status CallJniMethod(const jobject& obj, const jmethodID& method,
const T& arg) WARN_UNUSED_RESULT;
@@ -434,6 +438,12 @@ SPECIALIZE_PRIMITIVE_TO_VALUE(float, f);
SPECIALIZE_PRIMITIVE_TO_VALUE(double, d);
#undef SPECIALIZE_PRIMITIVE_TO_VALUE
+template <typename T, typename R>
+inline Status JniUtil::CallStaticJniMethod(const jclass& cls, const jmethodID&
method,
+ const T& arg, R* response) {
+ return JniCall::static_method(cls,
method).with_thrift_arg(arg).Call(response);
+}
+
template <typename T>
inline Status JniUtil::CallJniMethod(const jobject& obj, const jmethodID&
method,
const T& arg) {
diff --git a/bin/load-data.py b/bin/load-data.py
index 729dcb95b..57ad313de 100755
--- a/bin/load-data.py
+++ b/bin/load-data.py
@@ -87,6 +87,7 @@ WORKLOAD_DIR = options.workload_dir
DATASET_DIR = options.dataset_dir
TESTDATA_BIN_DIR = os.path.join(os.environ['IMPALA_HOME'], 'testdata/bin')
AVRO_SCHEMA_DIR = "avro_schemas"
+TESTDATA_JCEKS_DIR = os.path.join(os.environ['IMPALA_HOME'], 'testdata/jceks')
GENERATE_SCHEMA_CMD = "generate-schema-statements.py --exploration_strategy=%s
"\
"--workload=%s --scale_factor=%s --verbose"
@@ -299,6 +300,14 @@ def hive_exec_query_files_parallel(thread_pool,
query_files, step_name):
exec_query_files_parallel(thread_pool, query_files, 'hive', step_name)
+def exec_hadoop_credential_cmd(secret_key, secret, provider_path,
exit_on_error=True):
+ cmd = ("%s credential create %s -value %s -provider %s"
+ % (HADOOP_CMD, secret_key, secret, provider_path))
+ LOG.info("Executing Hadoop command: " + cmd)
+ exec_cmd(cmd, error_msg="Error executing Hadoop command, exiting",
+ exit_on_error=exit_on_error)
+
+
def main():
logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%H:%M:%S')
LOG.setLevel(logging.DEBUG)
@@ -308,6 +317,12 @@ def main():
#
LOG.debug(' '.join(sys.argv))
+ jceks_path = TESTDATA_JCEKS_DIR + "/test.jceks"
+ if os.path.exists(jceks_path):
+ os.remove(jceks_path)
+ exec_hadoop_credential_cmd("openai-api-key-secret", "secret",
+ "localjceks://file" + jceks_path)
+
all_workloads = available_workloads(WORKLOAD_DIR)
workloads = []
if options.workloads is None:
diff --git a/bin/rat_exclude_files.txt b/bin/rat_exclude_files.txt
index 5e9b5a0d2..51f6acadc 100644
--- a/bin/rat_exclude_files.txt
+++ b/bin/rat_exclude_files.txt
@@ -189,6 +189,7 @@
testdata/impala-profiles/impala_profile_log_tpcds_compute_stats_v2
testdata/impala-profiles/impala_profile_log_tpcds_compute_stats_v2_default.expected.txt
testdata/impala-profiles/impala_profile_log_tpcds_compute_stats_v2_extended.expected.txt
testdata/hive_benchmark/grepTiny/part-00000
+testdata/jceks/.gitkeep
testdata/jwt/*.json
testdata/jwt/jwt_expired
testdata/jwt/jwt_signed
diff --git a/common/function-registry/impala_functions.py
b/common/function-registry/impala_functions.py
index 056eec1df..8f61b80bb 100644
--- a/common/function-registry/impala_functions.py
+++ b/common/function-registry/impala_functions.py
@@ -513,6 +513,10 @@ visible_functions = [
[['repeat'], 'STRING', ['STRING', 'BIGINT'],
'impala::StringFunctions::Repeat'],
[['lpad'], 'STRING', ['STRING', 'BIGINT', 'STRING'],
'impala::StringFunctions::Lpad'],
[['rpad'], 'STRING', ['STRING', 'BIGINT', 'STRING'],
'impala::StringFunctions::Rpad'],
+ [['ai_generate_text'], 'STRING', ['STRING', 'STRING', 'STRING', 'STRING',
'STRING'],
+ 'impala::AiFunctions::AiGenerateText'],
+ [['ai_generate_text_default'], 'STRING', ['STRING'],
+ 'impala::AiFunctions::AiGenerateTextDefault'],
[['bytes'], 'INT', ['STRING'], 'impala::StringFunctions::Bytes'],
[['length'], 'INT', ['STRING'], 'impala::StringFunctions::Length'],
[['length'], 'INT', ['BINARY'], 'impala::StringFunctions::Bytes'],
diff --git a/fe/src/main/java/org/apache/impala/service/JniFrontend.java
b/fe/src/main/java/org/apache/impala/service/JniFrontend.java
index b4603c031..5f08f0ab3 100644
--- a/fe/src/main/java/org/apache/impala/service/JniFrontend.java
+++ b/fe/src/main/java/org/apache/impala/service/JniFrontend.java
@@ -94,6 +94,7 @@ import org.apache.impala.thrift.TShowGrantPrincipalParams;
import org.apache.impala.thrift.TShowRolesParams;
import org.apache.impala.thrift.TShowStatsOp;
import org.apache.impala.thrift.TShowStatsParams;
+import org.apache.impala.thrift.TStringLiteral;
import org.apache.impala.thrift.TDescribeHistoryParams;
import org.apache.impala.thrift.TSessionState;
import org.apache.impala.thrift.TTableName;
@@ -133,6 +134,8 @@ public class JniFrontend {
private final static TBinaryProtocol.Factory protocolFactory_ =
new TBinaryProtocol.Factory();
private final Frontend frontend_;
+ public final static String KEYSTORE_ERROR_MSG = "Failed to get password
from " +
+ "keystore, error: invalid key '%s' or password doesn't exist";
/**
* Create a new instance of the Jni Frontend.
@@ -805,6 +808,31 @@ public class JniFrontend {
}
}
+ /**
+ * Returns secret from the configured KeyStore.
+ * @param secretKeyRequest the serialized secret key to be used for
extracting secret.
+ */
+ public static String getSecretFromKeyStore(byte[] secretKeyRequest)
+ throws ImpalaException {
+ final TStringLiteral secretKey = new TStringLiteral();
+ JniUtil.deserializeThrift(protocolFactory_, secretKey, secretKeyRequest);
+ String secret = null;
+ try {
+ char[] secretCharArray = CONF.getPassword(secretKey.getValue());
+ if (secretCharArray != null) {
+ secret = new String(secretCharArray);
+ } else {
+ String errMsg = String.format(KEYSTORE_ERROR_MSG,
secretKey.getValue());
+ LOG.error(errMsg);
+ throw new InternalException(errMsg);
+ }
+ } catch (IOException e) {
+ LOG.error("Failed to get password from keystore, error: " + e);
+ throw new InternalException(e.getMessage());
+ }
+ return secret;
+ }
+
public String validateSaml2Bearer(byte[] serializedRequest) throws
ImpalaException{
Preconditions.checkNotNull(frontend_);
Preconditions.checkNotNull(frontend_.getSaml2Client());
diff --git a/fe/src/test/java/org/apache/impala/service/JniFrontendTest.java
b/fe/src/test/java/org/apache/impala/service/JniFrontendTest.java
index 771b1c785..e5314ce56 100644
--- a/fe/src/test/java/org/apache/impala/service/JniFrontendTest.java
+++ b/fe/src/test/java/org/apache/impala/service/JniFrontendTest.java
@@ -33,7 +33,10 @@ import
org.apache.hadoop.security.JniBasedUnixGroupsNetgroupMappingWithFallback;
import org.apache.hadoop.security.ShellBasedUnixGroupsMapping;
import org.apache.hadoop.security.ShellBasedUnixGroupsNetgroupMapping;
import org.apache.impala.common.ImpalaException;
+import org.apache.impala.common.InternalException;
+import org.apache.impala.common.JniUtil;
import org.apache.impala.thrift.TBackendGflags;
+import org.apache.impala.thrift.TStringLiteral;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
@@ -123,4 +126,26 @@ public class JniFrontendTest {
socketDir.getParentFile().delete();
}
}
+
+ /**
+ * This test validates that JniFrontend::getSecretFromKeyStore function can
return
+ * the secret from the configured Jceks KeyStore
+ */
+ @Test
+ public void testGetSecretFromKeyStore() throws ImpalaException {
+ // valid secret-key returns the correct secret
+ TStringLiteral secretKey = new TStringLiteral("openai-api-key-secret");
+ byte[] secretKeyBytes = JniUtil.serializeToThrift(secretKey);
+ String secret = JniFrontend.getSecretFromKeyStore(secretKeyBytes);
+ assertEquals(secret, "secret");
+ // invalid secret-key returns error
+ secretKey = new TStringLiteral("dummy-secret");
+ secretKeyBytes = JniUtil.serializeToThrift(secretKey);
+ try {
+ secret = JniFrontend.getSecretFromKeyStore(secretKeyBytes);
+ } catch (InternalException e) {
+ assertEquals(e.getMessage(),
+ String.format(JniFrontend.KEYSTORE_ERROR_MSG, secretKey.getValue()));
+ }
+ }
}
diff --git
a/testdata/cluster/node_templates/common/etc/hadoop/conf/core-site.xml.py
b/testdata/cluster/node_templates/common/etc/hadoop/conf/core-site.xml.py
index a80d088ab..72f1ac233 100644
--- a/testdata/cluster/node_templates/common/etc/hadoop/conf/core-site.xml.py
+++ b/testdata/cluster/node_templates/common/etc/hadoop/conf/core-site.xml.py
@@ -24,6 +24,9 @@ import sys
kerberize = os.environ.get('IMPALA_KERBERIZE') == 'true'
target_filesystem = os.environ.get('TARGET_FILESYSTEM')
+jceks_keystore = ("localjceks://file" +
+ os.path.join(os.environ['IMPALA_HOME'], 'testdata/jceks/test.jceks'))
+
compression_codecs = [
'org.apache.hadoop.io.compress.GzipCodec',
'org.apache.hadoop.io.compress.DefaultCodec',
@@ -59,6 +62,9 @@ CONFIG = {
# Location of the KMS key provider
'hadoop.security.key.provider.path':
'kms://http@${INTERNAL_LISTEN_HOST}:9600/kms',
+ # Location of Jceks KeyStore
+ 'hadoop.security.credential.provider.path': jceks_keystore,
+
# Needed as long as multiple nodes are running on the same address. For
Impala
# testing only.
'yarn.scheduler.include-port-in-node-name': 'true',
diff --git a/testdata/jceks/.gitkeep b/testdata/jceks/.gitkeep
new file mode 100644
index 000000000..e69de29bb