This is an automated email from the ASF dual-hosted git repository.

michaelsmith pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/impala.git

commit 6a079be2909714652da8de0f7d4af83ae0d1097c
Author: Yida Wu <[email protected]>
AuthorDate: Tue Apr 16 12:43:14 2024 -0700

    IMPALA-13004: Fix heap-use-after-free error in ExprTest AiFunctionsTest
    
    The code previously returned a std::string_view that pointed into
    string data owned by a local rapidjson::Document. That Document is
    destroyed when the parsing function returns, so the
    std::string_view referenced memory that was no longer valid,
    leading to a heap-use-after-free error.
    
    This patch fixes the issue by changing the function to
    return a std::string instead of a std::string_view. Returning
    a std::string copies the data out of the rapidjson::Document,
    so the result has its own memory allocation and no longer
    depends on the destroyed rapidjson::Document.
    
    The patch also hoists the 'overrides' Document in
    AiGenerateTextInternal out of the 'if' block, because 'payload'
    might reference data owned by 'overrides' and is used after
    that block.
    
    Tests:
    Reran the ASAN build; the tests passed.
    
    Change-Id: I3bb9dcf9d72cce7ad37d5bc25821cf6ee55a8ab5
    Reviewed-on: http://gerrit.cloudera.org:8080/21315
    Reviewed-by: Impala Public Jenkins <[email protected]>
    Tested-by: Impala Public Jenkins <[email protected]>
---
 be/src/exprs/ai-functions-ir.cc    | 5 ++---
 be/src/exprs/ai-functions.h        | 2 +-
 be/src/exprs/ai-functions.inline.h | 9 ++++++---
 be/src/exprs/expr-test.cc          | 4 ++--
 4 files changed, 11 insertions(+), 9 deletions(-)

diff --git a/be/src/exprs/ai-functions-ir.cc b/be/src/exprs/ai-functions-ir.cc
index e482cb688..6def1a010 100644
--- a/be/src/exprs/ai-functions-ir.cc
+++ b/be/src/exprs/ai-functions-ir.cc
@@ -85,7 +85,7 @@ bool AiFunctions::is_api_endpoint_supported(const 
std::string_view& endpoint) {
       gstrncasestr(endpoint.data(), OPEN_AI_PUBLIC_ENDPOINT, endpoint.size()) 
!= nullptr);
 }
 
-std::string_view AiFunctions::AiGenerateTextParseOpenAiResponse(
+string AiFunctions::AiGenerateTextParseOpenAiResponse(
     const std::string_view& response) {
   rapidjson::Document document;
   document.Parse(response.data(), response.size());
@@ -120,8 +120,7 @@ std::string_view 
AiFunctions::AiGenerateTextParseOpenAiResponse(
     return AI_GENERATE_TXT_JSON_PARSE_ERROR;
   }
 
-  const rapidjson::Value& result = message[OPEN_AI_RESPONSE_FIELD_CONTENT];
-  return std::string_view(result.GetString(), result.GetStringLength());
+  return message[OPEN_AI_RESPONSE_FIELD_CONTENT].GetString();
 }
 
 StringVal AiFunctions::AiGenerateText(FunctionContext* ctx, const StringVal& 
endpoint,
diff --git a/be/src/exprs/ai-functions.h b/be/src/exprs/ai-functions.h
index c1d2e635e..0e6396b40 100644
--- a/be/src/exprs/ai-functions.h
+++ b/be/src/exprs/ai-functions.h
@@ -64,7 +64,7 @@ class AiFunctions {
       const StringVal& api_key_jceks_secret, const StringVal& params, const 
bool dry_run);
   /// Internal helper function for parsing OPEN AI's API response. Input 
parameter is the
   /// json representation of the OPEN AI's API response.
-  static std::string_view AiGenerateTextParseOpenAiResponse(
+  static std::string AiGenerateTextParseOpenAiResponse(
       const std::string_view& reponse);
 
   friend class ExprTest_AiFunctionsTest_Test;
diff --git a/be/src/exprs/ai-functions.inline.h 
b/be/src/exprs/ai-functions.inline.h
index 7f7bcfd92..9f143e2df 100644
--- a/be/src/exprs/ai-functions.inline.h
+++ b/be/src/exprs/ai-functions.inline.h
@@ -103,9 +103,12 @@ StringVal 
AiFunctions::AiGenerateTextInternal(FunctionContext* ctx,
       payload_allocator);
   message_array.PushBack(message, payload_allocator);
   payload.AddMember("messages", message_array, payload_allocator);
-  // Override additional params
+  // Override additional params.
+  // Caution: 'payload' might reference data owned by 'overrides'.
+  // To ensure valid access, place 'overrides' outside the 'if'
+  // statement before using 'payload'.
+  Document overrides;
   if (!fastpath && params.ptr != nullptr && params.len != 0) {
-    Document overrides;
     overrides.Parse(reinterpret_cast<char*>(params.ptr), params.len);
     if (overrides.HasParseError()) {
       LOG(WARNING) << AI_GENERATE_TXT_JSON_PARSE_ERROR << ": error code "
@@ -172,7 +175,7 @@ StringVal 
AiFunctions::AiGenerateTextInternal(FunctionContext* ctx,
         ctx, reinterpret_cast<const uint8_t*>(msg.c_str()), msg.size());
   }
   // Parse the JSON response string
-  std::string_view response = AiGenerateTextParseOpenAiResponse(
+  std::string response = AiGenerateTextParseOpenAiResponse(
       std::string_view(reinterpret_cast<char*>(resp.data()), resp.size()));
   VLOG(2) << "AI Generate Text: \nresponse: " << response;
   StringVal result(ctx, response.length());
diff --git a/be/src/exprs/expr-test.cc b/be/src/exprs/expr-test.cc
index a05615057..ec326ed5a 100644
--- a/be/src/exprs/expr-test.cc
+++ b/be/src/exprs/expr-test.cc
@@ -11336,12 +11336,12 @@ TEST_P(ExprTest, AiFunctionsTest) {
       << "\"total_tokens\": 73"
       << "},"
       << "\"system_fingerprint\": null}";
-  std::string_view res = 
AiFunctions::AiGenerateTextParseOpenAiResponse(response.str());
+  std::string res = 
AiFunctions::AiGenerateTextParseOpenAiResponse(response.str());
   string from_null("(\'\\\\0\')");
   string to_null("(\'\\0\')");
   size_t pos = content.find(from_null);
   content.replace(pos, from_null.length(), to_null);
-  EXPECT_EQ(string(res), content);
+  EXPECT_EQ(res, content);
 
   // resource cleanup
   pool.FreeAll();

Reply via email to