This is an automated email from the ASF dual-hosted git repository. michaelsmith pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/impala.git
commit 6a079be2909714652da8de0f7d4af83ae0d1097c Author: Yida Wu <[email protected]> AuthorDate: Tue Apr 16 12:43:14 2024 -0700 IMPALA-13004: Fix heap-use-after-free error in ExprTest AiFunctionsTest The issue is that the code previously used a std::string_view to refer to data owned by a local rapidjson::Document. However, that rapidjson::Document is destroyed when the function returns, right after the std::string_view is created. This meant the std::string_view referenced memory that was no longer valid, leading to a heap-use-after-free error. This patch fixes this issue by modifying the function to return a std::string instead of a std::string_view. When the function returns a string, it creates a copy of the data from rapidjson::Document. This ensures the returned string has its own memory allocation and doesn't rely on the destroyed rapidjson::Document. Tests: Reran the ASAN build; it passed. Change-Id: I3bb9dcf9d72cce7ad37d5bc25821cf6ee55a8ab5 Reviewed-on: http://gerrit.cloudera.org:8080/21315 Reviewed-by: Impala Public Jenkins <[email protected]> Tested-by: Impala Public Jenkins <[email protected]> --- be/src/exprs/ai-functions-ir.cc | 5 ++--- be/src/exprs/ai-functions.h | 2 +- be/src/exprs/ai-functions.inline.h | 9 ++++++--- be/src/exprs/expr-test.cc | 4 ++-- 4 files changed, 11 insertions(+), 9 deletions(-) diff --git a/be/src/exprs/ai-functions-ir.cc b/be/src/exprs/ai-functions-ir.cc index e482cb688..6def1a010 100644 --- a/be/src/exprs/ai-functions-ir.cc +++ b/be/src/exprs/ai-functions-ir.cc @@ -85,7 +85,7 @@ bool AiFunctions::is_api_endpoint_supported(const std::string_view& endpoint) { gstrncasestr(endpoint.data(), OPEN_AI_PUBLIC_ENDPOINT, endpoint.size()) != nullptr); } -std::string_view AiFunctions::AiGenerateTextParseOpenAiResponse( +string AiFunctions::AiGenerateTextParseOpenAiResponse( const std::string_view& response) { rapidjson::Document document; document.Parse(response.data(), response.size()); @@ -120,8 +120,7 @@ std::string_view 
AiFunctions::AiGenerateTextParseOpenAiResponse( return AI_GENERATE_TXT_JSON_PARSE_ERROR; } - const rapidjson::Value& result = message[OPEN_AI_RESPONSE_FIELD_CONTENT]; - return std::string_view(result.GetString(), result.GetStringLength()); + return message[OPEN_AI_RESPONSE_FIELD_CONTENT].GetString(); } StringVal AiFunctions::AiGenerateText(FunctionContext* ctx, const StringVal& endpoint, diff --git a/be/src/exprs/ai-functions.h b/be/src/exprs/ai-functions.h index c1d2e635e..0e6396b40 100644 --- a/be/src/exprs/ai-functions.h +++ b/be/src/exprs/ai-functions.h @@ -64,7 +64,7 @@ class AiFunctions { const StringVal& api_key_jceks_secret, const StringVal& params, const bool dry_run); /// Internal helper function for parsing OPEN AI's API response. Input parameter is the /// json representation of the OPEN AI's API response. - static std::string_view AiGenerateTextParseOpenAiResponse( + static std::string AiGenerateTextParseOpenAiResponse( const std::string_view& reponse); friend class ExprTest_AiFunctionsTest_Test; diff --git a/be/src/exprs/ai-functions.inline.h b/be/src/exprs/ai-functions.inline.h index 7f7bcfd92..9f143e2df 100644 --- a/be/src/exprs/ai-functions.inline.h +++ b/be/src/exprs/ai-functions.inline.h @@ -103,9 +103,12 @@ StringVal AiFunctions::AiGenerateTextInternal(FunctionContext* ctx, payload_allocator); message_array.PushBack(message, payload_allocator); payload.AddMember("messages", message_array, payload_allocator); - // Override additional params + // Override additional params. + // Caution: 'payload' might reference data owned by 'overrides'. + // To ensure valid access, place 'overrides' outside the 'if' + // statement before using 'payload'. 
+ Document overrides; if (!fastpath && params.ptr != nullptr && params.len != 0) { - Document overrides; overrides.Parse(reinterpret_cast<char*>(params.ptr), params.len); if (overrides.HasParseError()) { LOG(WARNING) << AI_GENERATE_TXT_JSON_PARSE_ERROR << ": error code " @@ -172,7 +175,7 @@ StringVal AiFunctions::AiGenerateTextInternal(FunctionContext* ctx, ctx, reinterpret_cast<const uint8_t*>(msg.c_str()), msg.size()); } // Parse the JSON response string - std::string_view response = AiGenerateTextParseOpenAiResponse( + std::string response = AiGenerateTextParseOpenAiResponse( std::string_view(reinterpret_cast<char*>(resp.data()), resp.size())); VLOG(2) << "AI Generate Text: \nresponse: " << response; StringVal result(ctx, response.length()); diff --git a/be/src/exprs/expr-test.cc b/be/src/exprs/expr-test.cc index a05615057..ec326ed5a 100644 --- a/be/src/exprs/expr-test.cc +++ b/be/src/exprs/expr-test.cc @@ -11336,12 +11336,12 @@ TEST_P(ExprTest, AiFunctionsTest) { << "\"total_tokens\": 73" << "}," << "\"system_fingerprint\": null}"; - std::string_view res = AiFunctions::AiGenerateTextParseOpenAiResponse(response.str()); + std::string res = AiFunctions::AiGenerateTextParseOpenAiResponse(response.str()); string from_null("(\'\\\\0\')"); string to_null("(\'\\0\')"); size_t pos = content.find(from_null); content.replace(pos, from_null.length(), to_null); - EXPECT_EQ(string(res), content); + EXPECT_EQ(res, content); // resource cleanup pool.FreeAll();
