https://github.com/ashgti updated https://github.com/llvm/llvm-project/pull/152367
>From 02417f51a23fbfd4d941b6f9b18e82fe8eb87566 Mon Sep 17 00:00:00 2001 From: John Harrison <harj...@google.com> Date: Tue, 5 Aug 2025 11:23:31 -0700 Subject: [PATCH] [lldb] Update JSONTransport to use MainLoop for reading. Reapply "[lldb] Update JSONTransport to use MainLoop for reading." (#152155) This reverts commit cd40281685f642ad879e33f3fda8d1faa136ebf4. This also includes some updates to try to address the platforms with failing tests. I updated the JSONTransport and tests to use std::function instead of llvm::unique_function. I think the tests were failing due to the unique_function not being moved correctly in the loop on some platforms. --- lldb/include/lldb/Host/JSONTransport.h | 113 +++++-- lldb/source/Host/common/JSONTransport.cpp | 167 ++++------ lldb/test/API/tools/lldb-dap/io/TestDAP_io.py | 27 +- lldb/tools/lldb-dap/DAP.cpp | 128 ++++---- lldb/tools/lldb-dap/DAP.h | 7 + lldb/tools/lldb-dap/Transport.h | 2 +- lldb/unittests/DAP/DAPTest.cpp | 11 +- lldb/unittests/DAP/TestBase.cpp | 26 +- lldb/unittests/DAP/TestBase.h | 20 ++ lldb/unittests/Host/JSONTransportTest.cpp | 299 +++++++++++++----- .../ProtocolServer/ProtocolMCPServerTest.cpp | 131 ++++---- 11 files changed, 573 insertions(+), 358 deletions(-) diff --git a/lldb/include/lldb/Host/JSONTransport.h b/lldb/include/lldb/Host/JSONTransport.h index 4087cdf2b42f7..98bce6e265356 100644 --- a/lldb/include/lldb/Host/JSONTransport.h +++ b/lldb/include/lldb/Host/JSONTransport.h @@ -13,13 +13,15 @@ #ifndef LLDB_HOST_JSONTRANSPORT_H #define LLDB_HOST_JSONTRANSPORT_H +#include "lldb/Host/MainLoopBase.h" #include "lldb/lldb-forward.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Error.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/JSON.h" -#include <chrono> +#include <string> #include <system_error> +#include <vector> namespace lldb_private { @@ -28,27 +30,33 @@ class TransportEOFError : public llvm::ErrorInfo<TransportEOFError> { static char ID; TransportEOFError() = default; - 
void log(llvm::raw_ostream &OS) const override { - OS << "transport end of file reached"; - } + void log(llvm::raw_ostream &OS) const override { OS << "transport EOF"; } std::error_code convertToErrorCode() const override { - return llvm::inconvertibleErrorCode(); + return std::make_error_code(std::errc::io_error); } }; -class TransportTimeoutError : public llvm::ErrorInfo<TransportTimeoutError> { +class TransportUnhandledContentsError + : public llvm::ErrorInfo<TransportUnhandledContentsError> { public: static char ID; - TransportTimeoutError() = default; + explicit TransportUnhandledContentsError(std::string unhandled_contents) + : m_unhandled_contents(unhandled_contents) {} void log(llvm::raw_ostream &OS) const override { - OS << "transport operation timed out"; + OS << "transport EOF with unhandled contents " << m_unhandled_contents; } std::error_code convertToErrorCode() const override { - return std::make_error_code(std::errc::timed_out); + return std::make_error_code(std::errc::bad_message); } + + const std::string &getUnhandledContents() const { + return m_unhandled_contents; + } + +private: + std::string m_unhandled_contents; }; class TransportInvalidError : public llvm::ErrorInfo<TransportInvalidError> { @@ -68,6 +76,10 @@ class TransportInvalidError : public llvm::ErrorInfo<TransportInvalidError> { /// A transport class that uses JSON for communication. class JSONTransport { public: + using ReadHandleUP = MainLoopBase::ReadHandleUP; + template <typename T> + using Callback = std::function<void(MainLoopBase &, const llvm::Expected<T>)>; + JSONTransport(lldb::IOObjectSP input, lldb::IOObjectSP output); virtual ~JSONTransport() = default; @@ -83,24 +95,69 @@ class JSONTransport { return WriteImpl(message); } - /// Reads the next message from the input stream. + /// Registers the transport with the MainLoop. 
template <typename T> - llvm::Expected<T> Read(const std::chrono::microseconds &timeout) { - llvm::Expected<std::string> message = ReadImpl(timeout); - if (!message) - return message.takeError(); - return llvm::json::parse<T>(/*JSON=*/*message); + llvm::Expected<ReadHandleUP> RegisterReadObject(MainLoopBase &loop, + Callback<T> callback) { + Status error; + ReadHandleUP handle = loop.RegisterReadObject( + m_input, + [this, callback](MainLoopBase &loop) { // Copy `callback`: the handler runs after this function returns. + char buffer[kReadBufferSize]; + size_t len = sizeof(buffer); + if (llvm::Error error = m_input->Read(buffer, len).takeError()) { + callback(loop, std::move(error)); + return; + } + + if (len) + m_buffer.append(std::string(buffer, len)); + + // If the buffer has contents, try parsing any pending messages. + if (!m_buffer.empty()) { + llvm::Expected<std::vector<std::string>> messages = Parse(); + if (llvm::Error error = messages.takeError()) { + callback(loop, std::move(error)); + return; + } + + for (const auto &message : *messages) + if constexpr (std::is_same<T, std::string>::value) + callback(loop, message); + else + callback(loop, llvm::json::parse<T>(message)); + } + + // On EOF, notify the callback after the remaining messages were + // handled. + if (len == 0) { + if (m_buffer.empty()) + callback(loop, llvm::make_error<TransportEOFError>()); + else + callback(loop, llvm::make_error<TransportUnhandledContentsError>( + m_buffer)); + } + }, + error); + if (error.Fail()) + return error.takeError(); + return handle; } protected: + template <typename...
Ts> inline auto Logv(const char *Fmt, Ts &&...Vals) { + Log(llvm::formatv(Fmt, std::forward<Ts>(Vals)...).str()); + } virtual void Log(llvm::StringRef message); virtual llvm::Error WriteImpl(const std::string &message) = 0; - virtual llvm::Expected<std::string> - ReadImpl(const std::chrono::microseconds &timeout) = 0; + virtual llvm::Expected<std::vector<std::string>> Parse() = 0; lldb::IOObjectSP m_input; lldb::IOObjectSP m_output; + std::string m_buffer; + + static constexpr size_t kReadBufferSize = 1024; }; /// A transport class for JSON with a HTTP header. @@ -111,14 +168,13 @@ class HTTPDelimitedJSONTransport : public JSONTransport { virtual ~HTTPDelimitedJSONTransport() = default; protected: - virtual llvm::Error WriteImpl(const std::string &message) override; - virtual llvm::Expected<std::string> - ReadImpl(const std::chrono::microseconds &timeout) override; - - // FIXME: Support any header. - static constexpr llvm::StringLiteral kHeaderContentLength = - "Content-Length: "; - static constexpr llvm::StringLiteral kHeaderSeparator = "\r\n\r\n"; + llvm::Error WriteImpl(const std::string &message) override; + llvm::Expected<std::vector<std::string>> Parse() override; + + static constexpr llvm::StringLiteral kHeaderContentLength = "Content-Length"; + static constexpr llvm::StringLiteral kHeaderFieldSeparator = ":"; + static constexpr llvm::StringLiteral kHeaderSeparator = "\r\n"; + static constexpr llvm::StringLiteral kEndOfHeader = "\r\n\r\n"; }; /// A transport class for JSON RPC. 
@@ -129,9 +185,8 @@ class JSONRPCTransport : public JSONTransport { virtual ~JSONRPCTransport() = default; protected: - virtual llvm::Error WriteImpl(const std::string &message) override; - virtual llvm::Expected<std::string> - ReadImpl(const std::chrono::microseconds &timeout) override; + llvm::Error WriteImpl(const std::string &message) override; + llvm::Expected<std::vector<std::string>> Parse() override; static constexpr llvm::StringLiteral kMessageSeparator = "\n"; }; diff --git a/lldb/source/Host/common/JSONTransport.cpp b/lldb/source/Host/common/JSONTransport.cpp index 546c12c8f7114..c3a3b06ecbced 100644 --- a/lldb/source/Host/common/JSONTransport.cpp +++ b/lldb/source/Host/common/JSONTransport.cpp @@ -7,17 +7,14 @@ //===----------------------------------------------------------------------===// #include "lldb/Host/JSONTransport.h" -#include "lldb/Utility/IOObject.h" #include "lldb/Utility/LLDBLog.h" #include "lldb/Utility/Log.h" -#include "lldb/Utility/SelectHelper.h" #include "lldb/Utility/Status.h" #include "lldb/lldb-forward.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Error.h" #include "llvm/Support/raw_ostream.h" -#include <optional> #include <string> #include <utility> @@ -25,64 +22,6 @@ using namespace llvm; using namespace lldb; using namespace lldb_private; -/// ReadFull attempts to read the specified number of bytes. If EOF is -/// encountered, an empty string is returned. -static Expected<std::string> -ReadFull(IOObject &descriptor, size_t length, - std::optional<std::chrono::microseconds> timeout = std::nullopt) { - if (!descriptor.IsValid()) - return llvm::make_error<TransportInvalidError>(); - - bool timeout_supported = true; - // FIXME: SelectHelper does not work with NativeFile on Win32. 
-#if _WIN32 - timeout_supported = descriptor.GetFdType() == IOObject::eFDTypeSocket; -#endif - - if (timeout && timeout_supported) { - SelectHelper sh; - sh.SetTimeout(*timeout); - sh.FDSetRead( - reinterpret_cast<lldb::socket_t>(descriptor.GetWaitableHandle())); - Status status = sh.Select(); - if (status.Fail()) { - // Convert timeouts into a specific error. - if (status.GetType() == lldb::eErrorTypePOSIX && - status.GetError() == ETIMEDOUT) - return make_error<TransportTimeoutError>(); - return status.takeError(); - } - } - - std::string data; - data.resize(length); - Status status = descriptor.Read(data.data(), length); - if (status.Fail()) - return status.takeError(); - - // Read returns '' on EOF. - if (length == 0) - return make_error<TransportEOFError>(); - - // Return the actual number of bytes read. - return data.substr(0, length); -} - -static Expected<std::string> -ReadUntil(IOObject &descriptor, StringRef delimiter, - std::optional<std::chrono::microseconds> timeout = std::nullopt) { - std::string buffer; - buffer.reserve(delimiter.size() + 1); - while (!llvm::StringRef(buffer).ends_with(delimiter)) { - Expected<std::string> next = - ReadFull(descriptor, buffer.empty() ? 
delimiter.size() : 1, timeout); - if (auto Err = next.takeError()) - return std::move(Err); - buffer += *next; - } - return buffer.substr(0, buffer.size() - delimiter.size()); -} - JSONTransport::JSONTransport(IOObjectSP input, IOObjectSP output) : m_input(std::move(input)), m_output(std::move(output)) {} @@ -90,80 +29,80 @@ void JSONTransport::Log(llvm::StringRef message) { LLDB_LOG(GetLog(LLDBLog::Host), "{0}", message); } -Expected<std::string> -HTTPDelimitedJSONTransport::ReadImpl(const std::chrono::microseconds &timeout) { - if (!m_input || !m_input->IsValid()) - return llvm::make_error<TransportInvalidError>(); +// Parses messages based on +// https://microsoft.github.io/debug-adapter-protocol/overview#base-protocol +Expected<std::vector<std::string>> HTTPDelimitedJSONTransport::Parse() { + std::vector<std::string> messages; + StringRef buffer = m_buffer; + while (buffer.contains(kEndOfHeader)) { + auto [headers, rest] = buffer.split(kEndOfHeader); + size_t content_length = 0; + // HTTP Headers are formatted like `<field-name> ':' [<field-value>]`. + for (const auto &header : llvm::split(headers, kHeaderSeparator)) { + auto [key, value] = header.split(kHeaderFieldSeparator); + // 'Content-Length' is the only meaningful key at the moment. Others are + // ignored. + if (!key.equals_insensitive(kHeaderContentLength)) + continue; + + value = value.trim(); + if (!llvm::to_integer(value, content_length, 10)) + return createStringError(std::errc::invalid_argument, + "invalid content length: %s", + value.str().c_str()); + } + + // Check if we have enough data. 
+ if (content_length > rest.size()) + break; - IOObject *input = m_input.get(); - Expected<std::string> message_header = - ReadFull(*input, kHeaderContentLength.size(), timeout); - if (!message_header) - return message_header.takeError(); - if (*message_header != kHeaderContentLength) - return createStringError(formatv("expected '{0}' and got '{1}'", - kHeaderContentLength, *message_header) - .str()); - - Expected<std::string> raw_length = ReadUntil(*input, kHeaderSeparator); - if (!raw_length) - return handleErrors(raw_length.takeError(), - [&](const TransportEOFError &E) -> llvm::Error { - return createStringError( - "unexpected EOF while reading header separator"); - }); - - size_t length; - if (!to_integer(*raw_length, length)) - return createStringError( - formatv("invalid content length {0}", *raw_length).str()); - - Expected<std::string> raw_json = ReadFull(*input, length); - if (!raw_json) - return handleErrors( - raw_json.takeError(), [&](const TransportEOFError &E) -> llvm::Error { - return createStringError("unexpected EOF while reading JSON"); - }); - - Log(llvm::formatv("--> {0}", *raw_json).str()); - - return raw_json; + StringRef body = rest.take_front(content_length); + buffer = rest.drop_front(content_length); + messages.emplace_back(body.str()); + Logv("--> {0}", body); + } + + // Store the remainder of the buffer for the next read callback. 
+ m_buffer = buffer.str(); + + return std::move(messages); } Error HTTPDelimitedJSONTransport::WriteImpl(const std::string &message) { if (!m_output || !m_output->IsValid()) return llvm::make_error<TransportInvalidError>(); - Log(llvm::formatv("<-- {0}", message).str()); + Logv("<-- {0}", message); std::string Output; raw_string_ostream OS(Output); - OS << kHeaderContentLength << message.length() << kHeaderSeparator << message; + OS << kHeaderContentLength << kHeaderFieldSeparator << ' ' << message.length() + << kHeaderSeparator << kHeaderSeparator << message; size_t num_bytes = Output.size(); return m_output->Write(Output.data(), num_bytes).takeError(); } -Expected<std::string> -JSONRPCTransport::ReadImpl(const std::chrono::microseconds &timeout) { - if (!m_input || !m_input->IsValid()) - return make_error<TransportInvalidError>(); - - IOObject *input = m_input.get(); - Expected<std::string> raw_json = - ReadUntil(*input, kMessageSeparator, timeout); - if (!raw_json) - return raw_json.takeError(); +Expected<std::vector<std::string>> JSONRPCTransport::Parse() { + std::vector<std::string> messages; + StringRef buf = m_buffer; + while (buf.contains(kMessageSeparator)) { + auto [raw_json, rest] = buf.split(kMessageSeparator); + buf = rest; + messages.emplace_back(raw_json.str()); + Logv("--> {0}", raw_json); + } - Log(llvm::formatv("--> {0}", *raw_json).str()); + // Store the remainder of the buffer for the next read callback. 
+ m_buffer = buf.str(); - return *raw_json; + return messages; } Error JSONRPCTransport::WriteImpl(const std::string &message) { if (!m_output || !m_output->IsValid()) return llvm::make_error<TransportInvalidError>(); - Log(llvm::formatv("<-- {0}", message).str()); + Logv("<-- {0}", message); std::string Output; llvm::raw_string_ostream OS(Output); @@ -173,5 +112,5 @@ Error JSONRPCTransport::WriteImpl(const std::string &message) { } char TransportEOFError::ID; -char TransportTimeoutError::ID; +char TransportUnhandledContentsError::ID; char TransportInvalidError::ID; diff --git a/lldb/test/API/tools/lldb-dap/io/TestDAP_io.py b/lldb/test/API/tools/lldb-dap/io/TestDAP_io.py index b72b98de412b4..af5c62a8c4eb5 100644 --- a/lldb/test/API/tools/lldb-dap/io/TestDAP_io.py +++ b/lldb/test/API/tools/lldb-dap/io/TestDAP_io.py @@ -8,6 +8,9 @@ import lldbdap_testcase import dap_server +EXIT_FAILURE = 1 +EXIT_SUCCESS = 0 + class TestDAP_io(lldbdap_testcase.DAPTestCaseBase): def launch(self): @@ -41,40 +44,44 @@ def test_eof_immediately(self): """ process = self.launch() process.stdin.close() - self.assertEqual(process.wait(timeout=5.0), 0) + self.assertEqual(process.wait(timeout=5.0), EXIT_SUCCESS) def test_invalid_header(self): """ - lldb-dap handles invalid message headers. + lldb-dap returns a failure exit code when the input stream is closed + with a malformed request header. """ process = self.launch() - process.stdin.write(b"not the corret message header") + process.stdin.write(b"not the correct message header") process.stdin.close() - self.assertEqual(process.wait(timeout=5.0), 1) + self.assertEqual(process.wait(timeout=5.0), EXIT_FAILURE) def test_partial_header(self): """ - lldb-dap handles parital message headers. + lldb-dap returns a failure exit code when the input stream is closed + with an incomplete message header in the message buffer.
""" process = self.launch() process.stdin.write(b"Content-Length: ") process.stdin.close() - self.assertEqual(process.wait(timeout=5.0), 1) + self.assertEqual(process.wait(timeout=5.0), EXIT_FAILURE) def test_incorrect_content_length(self): """ - lldb-dap handles malformed content length headers. + lldb-dap returns a failure exit code when reading malformed content + length headers. """ process = self.launch() process.stdin.write(b"Content-Length: abc") process.stdin.close() - self.assertEqual(process.wait(timeout=5.0), 1) + self.assertEqual(process.wait(timeout=5.0), EXIT_FAILURE) def test_partial_content_length(self): """ - lldb-dap handles partial messages. + lldb-dap returns a failure exit code when the input stream is closed + with a partial message in the message buffer. """ process = self.launch() process.stdin.write(b"Content-Length: 10\r\n\r\n{") process.stdin.close() - self.assertEqual(process.wait(timeout=5.0), 1) + self.assertEqual(process.wait(timeout=5.0), EXIT_FAILURE) diff --git a/lldb/tools/lldb-dap/DAP.cpp b/lldb/tools/lldb-dap/DAP.cpp index cbd3b14463e25..55c5c9347bf6f 100644 --- a/lldb/tools/lldb-dap/DAP.cpp +++ b/lldb/tools/lldb-dap/DAP.cpp @@ -23,13 +23,14 @@ #include "Transport.h" #include "lldb/API/SBBreakpoint.h" #include "lldb/API/SBCommandInterpreter.h" -#include "lldb/API/SBCommandReturnObject.h" #include "lldb/API/SBEvent.h" #include "lldb/API/SBLanguageRuntime.h" #include "lldb/API/SBListener.h" #include "lldb/API/SBProcess.h" #include "lldb/API/SBStream.h" -#include "lldb/Utility/IOObject.h" +#include "lldb/Host/JSONTransport.h" +#include "lldb/Host/MainLoop.h" +#include "lldb/Host/MainLoopBase.h" #include "lldb/Utility/Status.h" #include "lldb/lldb-defines.h" #include "lldb/lldb-enumerations.h" @@ -52,7 +53,7 @@ #include <cstdarg> #include <cstdint> #include <cstdio> -#include <fstream> +#include <functional> #include <future> #include <memory> #include <mutex> @@ -919,6 +920,8 @@ llvm::Error DAP::Disconnect(bool terminateDebuggee) { 
SendTerminatedEvent(); disconnecting = true; + m_loop.AddPendingCallback( + [](MainLoopBase &loop) { loop.RequestTermination(); }); return ToError(error); } @@ -949,75 +952,74 @@ static std::optional<T> getArgumentsIfRequest(const Message &pm, return args; } -llvm::Error DAP::Loop() { - // Can't use \a std::future<llvm::Error> because it doesn't compile on - // Windows. - std::future<lldb::SBError> queue_reader = - std::async(std::launch::async, [&]() -> lldb::SBError { - llvm::set_thread_name(transport.GetClientName() + ".transport_handler"); - auto cleanup = llvm::make_scope_exit([&]() { - // Ensure we're marked as disconnecting when the reader exits. - disconnecting = true; - m_queue_cv.notify_all(); - }); - - while (!disconnecting) { - llvm::Expected<Message> next = - transport.Read<protocol::Message>(std::chrono::seconds(1)); - if (next.errorIsA<TransportEOFError>()) { - consumeError(next.takeError()); - break; - } +Status DAP::TransportHandler() { + llvm::set_thread_name(transport.GetClientName() + ".transport_handler"); - // If the read timed out, continue to check if we should disconnect. - if (next.errorIsA<TransportTimeoutError>()) { - consumeError(next.takeError()); - continue; - } + auto cleanup = llvm::make_scope_exit([&]() { + // Ensure we're marked as disconnecting when the reader exits. 
+ disconnecting = true; + m_queue_cv.notify_all(); + }); - if (llvm::Error err = next.takeError()) { - lldb::SBError errWrapper; - errWrapper.SetErrorString(llvm::toString(std::move(err)).c_str()); - return errWrapper; - } + Status status; + auto handle = transport.RegisterReadObject<protocol::Message>( + m_loop, + [&](MainLoopBase &loop, llvm::Expected<protocol::Message> message) { + if (message.errorIsA<TransportEOFError>()) { + llvm::consumeError(message.takeError()); + loop.RequestTermination(); + return; + } - if (const protocol::Request *req = - std::get_if<protocol::Request>(&*next); - req && req->command == "disconnect") - disconnecting = true; - - const std::optional<CancelArguments> cancel_args = - getArgumentsIfRequest<CancelArguments>(*next, "cancel"); - if (cancel_args) { - { - std::lock_guard<std::mutex> guard(m_cancelled_requests_mutex); - if (cancel_args->requestId) - m_cancelled_requests.insert(*cancel_args->requestId); - } + if (llvm::Error err = message.takeError()) { + status = Status::FromError(std::move(err)); + loop.RequestTermination(); + return; + } - // If a cancel is requested for the active request, make a best - // effort attempt to interrupt. 
- std::lock_guard<std::mutex> guard(m_active_request_mutex); - if (m_active_request && - cancel_args->requestId == m_active_request->seq) { - DAP_LOG( - log, - "({0}) interrupting inflight request (command={1} seq={2})", - transport.GetClientName(), m_active_request->command, - m_active_request->seq); - debugger.RequestInterrupt(); - } - } + if (const protocol::Request *req = + std::get_if<protocol::Request>(&*message); + req && req->command == "disconnect") // Flag a client "disconnect" request as soon as it is read. + disconnecting = true; + const std::optional<CancelArguments> cancel_args = + getArgumentsIfRequest<CancelArguments>(*message, "cancel"); + if (cancel_args) { { - std::lock_guard<std::mutex> guard(m_queue_mutex); - m_queue.push_back(std::move(*next)); + std::lock_guard<std::mutex> guard(m_cancelled_requests_mutex); + if (cancel_args->requestId) + m_cancelled_requests.insert(*cancel_args->requestId); + } + + // If a cancel is requested for the active request, make a best + // effort attempt to interrupt. + std::lock_guard<std::mutex> guard(m_active_request_mutex); + if (m_active_request && + cancel_args->requestId == m_active_request->seq) { + DAP_LOG(log, + "({0}) interrupting inflight request (command={1} seq={2})", + transport.GetClientName(), m_active_request->command, + m_active_request->seq); + debugger.RequestInterrupt(); } - m_queue_cv.notify_one(); } - return lldb::SBError(); + std::lock_guard<std::mutex> guard(m_queue_mutex); + m_queue.push_back(std::move(*message)); + m_queue_cv.notify_one(); }); + if (auto err = handle.takeError()) + return Status::FromError(std::move(err)); + if (llvm::Error err = m_loop.Run().takeError()) + return Status::FromError(std::move(err)); + return status; +} + +llvm::Error DAP::Loop() { + // Can't use \a std::future<llvm::Error> because it doesn't compile on + // Windows.
+ std::future<Status> queue_reader = + std::async(std::launch::async, &DAP::TransportHandler, this); auto cleanup = llvm::make_scope_exit([&]() { out.Stop(); @@ -1043,7 +1045,7 @@ llvm::Error DAP::Loop() { "unhandled packet"); } - return ToError(queue_reader.get()); + return queue_reader.get().takeError(); } lldb::SBError DAP::WaitForProcessToStop(std::chrono::seconds seconds) { diff --git a/lldb/tools/lldb-dap/DAP.h b/lldb/tools/lldb-dap/DAP.h index af4aabaafaae8..b0e9fa9c16b75 100644 --- a/lldb/tools/lldb-dap/DAP.h +++ b/lldb/tools/lldb-dap/DAP.h @@ -31,6 +31,8 @@ #include "lldb/API/SBMutex.h" #include "lldb/API/SBTarget.h" #include "lldb/API/SBThread.h" +#include "lldb/Host/MainLoop.h" +#include "lldb/Utility/Status.h" #include "lldb/lldb-types.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" @@ -424,6 +426,8 @@ struct DAP { const std::optional<std::vector<protocol::SourceBreakpoint>> &breakpoints, SourceBreakpointMap &existing_breakpoints); + lldb_private::Status TransportHandler(); + /// Registration of request handler. /// @{ void RegisterRequests(); @@ -451,6 +455,9 @@ struct DAP { std::mutex m_queue_mutex; std::condition_variable m_queue_cv; + // Loop for managing reading from the client. + lldb_private::MainLoop m_loop; + std::mutex m_cancelled_requests_mutex; llvm::SmallSet<int64_t, 4> m_cancelled_requests; diff --git a/lldb/tools/lldb-dap/Transport.h b/lldb/tools/lldb-dap/Transport.h index 51f62e718a0d0..9a7d8f424d40e 100644 --- a/lldb/tools/lldb-dap/Transport.h +++ b/lldb/tools/lldb-dap/Transport.h @@ -29,7 +29,7 @@ class Transport : public lldb_private::HTTPDelimitedJSONTransport { lldb::IOObjectSP input, lldb::IOObjectSP output); virtual ~Transport() = default; - virtual void Log(llvm::StringRef message) override; + void Log(llvm::StringRef message) override; /// Returns the name of this transport client, for example `stdin/stdout` or /// `client_1`. 
diff --git a/lldb/unittests/DAP/DAPTest.cpp b/lldb/unittests/DAP/DAPTest.cpp index 40ffaf87c9c45..138910d917424 100644 --- a/lldb/unittests/DAP/DAPTest.cpp +++ b/lldb/unittests/DAP/DAPTest.cpp @@ -9,10 +9,8 @@ #include "DAP.h" #include "Protocol/ProtocolBase.h" #include "TestBase.h" -#include "Transport.h" #include "llvm/Testing/Support/Error.h" #include "gtest/gtest.h" -#include <chrono> #include <memory> #include <optional> @@ -32,8 +30,9 @@ TEST_F(DAPTest, SendProtocolMessages) { /*transport=*/*to_dap, }; dap.Send(Event{/*event=*/"my-event", /*body=*/std::nullopt}); - ASSERT_THAT_EXPECTED( - from_dap->Read<protocol::Message>(std::chrono::milliseconds(1)), - HasValue(testing::VariantWith<Event>(testing::FieldsAre( - /*event=*/"my-event", /*body=*/std::nullopt)))); + RunOnce<protocol::Message>([&](llvm::Expected<protocol::Message> message) { + ASSERT_THAT_EXPECTED( + message, HasValue(testing::VariantWith<Event>(testing::FieldsAre( + /*event=*/"my-event", /*body=*/std::nullopt)))); + }); } diff --git a/lldb/unittests/DAP/TestBase.cpp b/lldb/unittests/DAP/TestBase.cpp index 94b9559b9ca70..8f9b098c8b1e1 100644 --- a/lldb/unittests/DAP/TestBase.cpp +++ b/lldb/unittests/DAP/TestBase.cpp @@ -12,9 +12,11 @@ #include "lldb/API/SBDefines.h" #include "lldb/API/SBStructuredData.h" #include "lldb/Host/File.h" +#include "lldb/Host/MainLoop.h" #include "lldb/Host/Pipe.h" #include "lldb/lldb-forward.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/Error.h" #include "llvm/Testing/Support/Error.h" #include "gtest/gtest.h" #include <memory> @@ -25,6 +27,8 @@ using namespace lldb_dap; using namespace lldb_dap::protocol; using namespace lldb_dap_tests; using lldb_private::File; +using lldb_private::MainLoop; +using lldb_private::MainLoopBase; using lldb_private::NativeFile; using lldb_private::Pipe; @@ -118,14 +122,18 @@ void DAPTestBase::LoadCore() { std::vector<Message> DAPTestBase::DrainOutput() { std::vector<Message> msgs; output.CloseWriteFileDescriptor(); - while 
(true) { - Expected<Message> next = - from_dap->Read<protocol::Message>(std::chrono::milliseconds(1)); - if (!next) { - consumeError(next.takeError()); - break; - } - msgs.push_back(*next); - } + auto handle = from_dap->RegisterReadObject<protocol::Message>( + loop, [&](MainLoopBase &loop, Expected<protocol::Message> next) { + if (llvm::Error error = next.takeError()) { + loop.RequestTermination(); + consumeError(std::move(error)); + return; + } + + msgs.push_back(*next); + }); + + consumeError(handle.takeError()); + consumeError(loop.Run().takeError()); return msgs; } diff --git a/lldb/unittests/DAP/TestBase.h b/lldb/unittests/DAP/TestBase.h index 50884b1d7feb9..50d069e401741 100644 --- a/lldb/unittests/DAP/TestBase.h +++ b/lldb/unittests/DAP/TestBase.h @@ -10,6 +10,7 @@ #include "Protocol/ProtocolBase.h" #include "TestingSupport/Host/PipeTestUtilities.h" #include "Transport.h" +#include "lldb/Host/MainLoop.h" #include "llvm/ADT/StringRef.h" #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -22,8 +23,27 @@ class TransportBase : public PipePairTest { protected: std::unique_ptr<lldb_dap::Transport> to_dap; std::unique_ptr<lldb_dap::Transport> from_dap; + lldb_private::MainLoop loop; void SetUp() override; + + template <typename P> + void RunOnce(const std::function<void(llvm::Expected<P>)> &callback, + std::chrono::milliseconds timeout = std::chrono::seconds(1)) { + auto handle = from_dap->RegisterReadObject<P>( + loop, [&](lldb_private::MainLoopBase &loop, llvm::Expected<P> message) { + callback(std::move(message)); + loop.RequestTermination(); + }); + loop.AddCallback( + [&](lldb_private::MainLoopBase &loop) { + loop.RequestTermination(); + FAIL() << "timeout waiting for read callback"; + }, + timeout); + ASSERT_THAT_EXPECTED(handle, llvm::Succeeded()); + ASSERT_THAT_ERROR(loop.Run().takeError(), llvm::Succeeded()); + } }; /// Matches an "output" event. 
diff --git a/lldb/unittests/Host/JSONTransportTest.cpp b/lldb/unittests/Host/JSONTransportTest.cpp index 2f0846471688c..6eae400aa4f6b 100644 --- a/lldb/unittests/Host/JSONTransportTest.cpp +++ b/lldb/unittests/Host/JSONTransportTest.cpp @@ -1,4 +1,4 @@ -//===-- JSONTransportTest.cpp ---------------------------------------------===// +//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -9,14 +9,42 @@ #include "lldb/Host/JSONTransport.h" #include "TestingSupport/Host/PipeTestUtilities.h" #include "lldb/Host/File.h" +#include "lldb/Host/MainLoop.h" +#include "lldb/Host/MainLoopBase.h" +#include "llvm/ADT/FunctionExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/JSON.h" +#include "llvm/Testing/Support/Error.h" +#include "gtest/gtest.h" +#include <chrono> +#include <cstddef> +#include <future> +#include <memory> +#include <string> using namespace llvm; using namespace lldb_private; namespace { + +struct JSONTestType { + std::string str; +}; + +json::Value toJSON(const JSONTestType &T) { + return json::Object{{"str", T.str}}; +} + +bool fromJSON(const json::Value &V, JSONTestType &T, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("str", T.str); +} + template <typename T> class JSONTransportTest : public PipePairTest { protected: std::unique_ptr<JSONTransport> transport; + MainLoop loop; void SetUp() override { PipePairTest::SetUp(); @@ -28,68 +56,193 @@ template <typename T> class JSONTransportTest : public PipePairTest { File::eOpenOptionWriteOnly, NativeFile::Unowned)); } + + template <typename P> + Expected<P> + RunOnce(std::chrono::milliseconds timeout = std::chrono::seconds(1)) { + std::promise<Expected<P>> promised_message; + std::future<Expected<P>> future_message = promised_message.get_future(); + RunUntil<P>( + 
[&](Expected<P> message) mutable -> bool { + promised_message.set_value(std::move(message)); + return /*keep_going*/ false; + }, + timeout); + return future_message.get(); + } + + /// RunUntil runs the event loop until the callback returns `false` or a + /// timeout has occurred. + template <typename P> + void RunUntil(std::function<bool(Expected<P>)> callback, + std::chrono::milliseconds timeout = std::chrono::seconds(1)) { + auto handle = transport->RegisterReadObject<P>( + loop, [&](MainLoopBase &loop, Expected<P> message) mutable { + bool keep_going = callback(std::move(message)); + if (!keep_going) + loop.RequestTermination(); + }); + loop.AddCallback( + [&](MainLoopBase &loop) mutable { + loop.RequestTermination(); + callback(createStringError("timeout")); + }, + timeout); + EXPECT_THAT_EXPECTED(handle, Succeeded()); + EXPECT_THAT_ERROR(loop.Run().takeError(), Succeeded()); + } + + template <typename... Ts> llvm::Expected<size_t> Write(Ts... args) { + std::string message; + for (const auto &arg : {args...}) + message += Encode(arg); + return input.Write(message.data(), message.size()); + } + + virtual std::string Encode(const json::Value &) = 0; }; class HTTPDelimitedJSONTransportTest : public JSONTransportTest<HTTPDelimitedJSONTransport> { public: using JSONTransportTest::JSONTransportTest; + + std::string Encode(const json::Value &V) override { + std::string msg; + raw_string_ostream OS(msg); + OS << formatv("{0}", V); + return formatv("Content-Length: {0}\r\nContent-type: " + "text/json\r\n\r\n{1}", + msg.size(), msg) + .str(); + } }; class JSONRPCTransportTest : public JSONTransportTest<JSONRPCTransport> { public: using JSONTransportTest::JSONTransportTest; -}; -struct JSONTestType { - std::string str; + std::string Encode(const json::Value &V) override { + std::string msg; + raw_string_ostream OS(msg); + OS << formatv("{0}\n", V); + return msg; + } }; -llvm::json::Value toJSON(const JSONTestType &T) { - return llvm::json::Object{{"str", T.str}}; -} - 
-bool fromJSON(const llvm::json::Value &V, JSONTestType &T, llvm::json::Path P) { - llvm::json::ObjectMapper O(V, P); - return O && O.map("str", T.str); -} } // namespace TEST_F(HTTPDelimitedJSONTransportTest, MalformedRequests) { - std::string malformed_header = "COnTent-LenGth: -1{}\r\n\r\nnotjosn"; + std::string malformed_header = + "COnTent-LenGth: -1\r\nContent-Type: text/json\r\n\r\nnotjosn"; ASSERT_THAT_EXPECTED( input.Write(malformed_header.data(), malformed_header.size()), Succeeded()); - ASSERT_THAT_EXPECTED( - transport->Read<JSONTestType>(std::chrono::milliseconds(1)), - FailedWithMessage( - "expected 'Content-Length: ' and got 'COnTent-LenGth: '")); + ASSERT_THAT_EXPECTED(RunOnce<JSONTestType>(), + FailedWithMessage("invalid content length: -1")); } TEST_F(HTTPDelimitedJSONTransportTest, Read) { - std::string json = R"json({"str": "foo"})json"; - std::string message = - formatv("Content-Length: {0}\r\n\r\n{1}", json.size(), json).str(); - ASSERT_THAT_EXPECTED(input.Write(message.data(), message.size()), + ASSERT_THAT_EXPECTED(Write(JSONTestType{"foo"}), Succeeded()); + ASSERT_THAT_EXPECTED(RunOnce<JSONTestType>(), + HasValue(testing::FieldsAre(/*str=*/"foo"))); +} + +TEST_F(HTTPDelimitedJSONTransportTest, ReadMultipleMessagesInSingleWrite) { + ASSERT_THAT_EXPECTED(Write(JSONTestType{"one"}, JSONTestType{"two"}), Succeeded()); + unsigned count = 0; + RunUntil<JSONTestType>([&](Expected<JSONTestType> message) -> bool { + if (count == 0) { + EXPECT_THAT_EXPECTED(message, + HasValue(testing::FieldsAre(/*str=*/"one"))); + } else if (count == 1) { + EXPECT_THAT_EXPECTED(message, + HasValue(testing::FieldsAre(/*str=*/"two"))); + } + + count++; + return count < 2; + }); +} + +TEST_F(HTTPDelimitedJSONTransportTest, ReadAcrossMultipleChunks) { + std::string long_str = std::string(2048, 'x'); + ASSERT_THAT_EXPECTED(Write(JSONTestType{long_str}), Succeeded()); + ASSERT_THAT_EXPECTED(RunOnce<JSONTestType>(), + HasValue(testing::FieldsAre(/*str=*/long_str))); +} + 
+TEST_F(HTTPDelimitedJSONTransportTest, ReadPartialMessage) { + std::string message = Encode(JSONTestType{"foo"}); + std::string part1 = message.substr(0, 28); + std::string part2 = message.substr(28); + + ASSERT_THAT_EXPECTED(input.Write(part1.data(), part1.size()), Succeeded()); + ASSERT_THAT_EXPECTED( - transport->Read<JSONTestType>(std::chrono::milliseconds(1)), - HasValue(testing::FieldsAre(/*str=*/"foo"))); + RunOnce<JSONTestType>(/*timeout=*/std::chrono::milliseconds(10)), + FailedWithMessage("timeout")); + + ASSERT_THAT_EXPECTED(input.Write(part2.data(), part2.size()), Succeeded()); + + ASSERT_THAT_EXPECTED(RunOnce<JSONTestType>(), + HasValue(testing::FieldsAre(/*str=*/"foo"))); +} + +TEST_F(HTTPDelimitedJSONTransportTest, ReadWithZeroByteWrites) { + std::string message = Encode(JSONTestType{"foo"}); + std::string part1 = message.substr(0, 28); + std::string part2 = message.substr(28); + + ASSERT_THAT_EXPECTED(input.Write(part1.data(), part1.size()), Succeeded()); + ASSERT_THAT_EXPECTED( + RunOnce<JSONTestType>(/*timeout=*/std::chrono::milliseconds(10)), + FailedWithMessage("timeout")); + + ASSERT_THAT_EXPECTED(input.Write(part1.data(), 0), + Succeeded()); // zero-byte write. 
+ + ASSERT_THAT_EXPECTED( + RunOnce<JSONTestType>(/*timeout=*/std::chrono::milliseconds(10)), + FailedWithMessage("timeout")); + + ASSERT_THAT_EXPECTED(input.Write(part2.data(), part2.size()), Succeeded()); + + ASSERT_THAT_EXPECTED(RunOnce<JSONTestType>(), + HasValue(testing::FieldsAre(/*str=*/"foo"))); } TEST_F(HTTPDelimitedJSONTransportTest, ReadWithEOF) { input.CloseWriteFileDescriptor(); - ASSERT_THAT_EXPECTED( - transport->Read<JSONTestType>(std::chrono::milliseconds(1)), - Failed<TransportEOFError>()); + ASSERT_THAT_EXPECTED(RunOnce<JSONTestType>(), Failed<TransportEOFError>()); +} + +TEST_F(HTTPDelimitedJSONTransportTest, ReaderWithUnhandledData) { + std::string json = R"json({"str": "foo"})json"; + std::string message = + formatv("Content-Length: {0}\r\nContent-type: text/json\r\n\r\n{1}", + json.size(), json) + .str(); + // Write an incomplete message and close the handle. + ASSERT_THAT_EXPECTED(input.Write(message.data(), message.size() - 1), + Succeeded()); + input.CloseWriteFileDescriptor(); + ASSERT_THAT_EXPECTED(RunOnce<JSONTestType>(), + Failed<TransportUnhandledContentsError>()); } +TEST_F(HTTPDelimitedJSONTransportTest, NoDataTimeout) { + ASSERT_THAT_EXPECTED( + RunOnce<JSONTestType>(/*timeout=*/std::chrono::milliseconds(10)), + FailedWithMessage("timeout")); +} TEST_F(HTTPDelimitedJSONTransportTest, InvalidTransport) { transport = std::make_unique<HTTPDelimitedJSONTransport>(nullptr, nullptr); - ASSERT_THAT_EXPECTED( - transport->Read<JSONTestType>(std::chrono::milliseconds(1)), - Failed<TransportInvalidError>()); + auto handle = transport->RegisterReadObject<JSONTestType>( + loop, [&](MainLoopBase &, llvm::Expected<JSONTestType>) {}); + ASSERT_THAT_EXPECTED(handle, FailedWithMessage("IO object is not valid.")); } TEST_F(HTTPDelimitedJSONTransportTest, Write) { @@ -108,26 +261,56 @@ TEST_F(JSONRPCTransportTest, MalformedRequests) { ASSERT_THAT_EXPECTED( input.Write(malformed_header.data(), malformed_header.size()), Succeeded()); - 
ASSERT_THAT_EXPECTED( - transport->Read<JSONTestType>(std::chrono::milliseconds(1)), - llvm::Failed()); + ASSERT_THAT_EXPECTED(RunOnce<JSONTestType>(), llvm::Failed()); } TEST_F(JSONRPCTransportTest, Read) { - std::string json = R"json({"str": "foo"})json"; - std::string message = formatv("{0}\n", json).str(); + ASSERT_THAT_EXPECTED(Write(JSONTestType{"foo"}), Succeeded()); + ASSERT_THAT_EXPECTED(RunOnce<JSONTestType>(), + HasValue(testing::FieldsAre(/*str=*/"foo"))); +} + +TEST_F(JSONRPCTransportTest, ReadAcrossMultipleChunks) { + std::string long_str = std::string(2048, 'x'); + std::string message = Encode(JSONTestType{long_str}); ASSERT_THAT_EXPECTED(input.Write(message.data(), message.size()), Succeeded()); + ASSERT_THAT_EXPECTED(RunOnce<JSONTestType>(), + HasValue(testing::FieldsAre(/*str=*/long_str))); +} + +TEST_F(JSONRPCTransportTest, ReadPartialMessage) { + std::string message = R"({"str": "foo"})" + "\n"; + std::string part1 = message.substr(0, 7); + std::string part2 = message.substr(7); + + ASSERT_THAT_EXPECTED(input.Write(part1.data(), part1.size()), Succeeded()); + ASSERT_THAT_EXPECTED( - transport->Read<JSONTestType>(std::chrono::milliseconds(1)), - HasValue(testing::FieldsAre(/*str=*/"foo"))); + RunOnce<JSONTestType>(/*timeout=*/std::chrono::milliseconds(10)), + FailedWithMessage("timeout")); + + ASSERT_THAT_EXPECTED(input.Write(part2.data(), part2.size()), Succeeded()); + + ASSERT_THAT_EXPECTED(RunOnce<JSONTestType>(), + HasValue(testing::FieldsAre(/*str=*/"foo"))); } TEST_F(JSONRPCTransportTest, ReadWithEOF) { input.CloseWriteFileDescriptor(); - ASSERT_THAT_EXPECTED( - transport->Read<JSONTestType>(std::chrono::milliseconds(1)), - Failed<TransportEOFError>()); + ASSERT_THAT_EXPECTED(RunOnce<JSONTestType>(), Failed<TransportEOFError>()); +} + +TEST_F(JSONRPCTransportTest, ReaderWithUnhandledData) { + std::string message = R"json({"str": "foo"})json" + "\n"; + // Write an incomplete message and close the handle. 
+ ASSERT_THAT_EXPECTED(input.Write(message.data(), message.size() - 1), + Succeeded()); + input.CloseWriteFileDescriptor(); + ASSERT_THAT_EXPECTED(RunOnce<JSONTestType>(), + Failed<TransportUnhandledContentsError>()); } TEST_F(JSONRPCTransportTest, Write) { @@ -143,39 +326,13 @@ TEST_F(JSONRPCTransportTest, Write) { TEST_F(JSONRPCTransportTest, InvalidTransport) { transport = std::make_unique<JSONRPCTransport>(nullptr, nullptr); - ASSERT_THAT_EXPECTED( - transport->Read<JSONTestType>(std::chrono::milliseconds(1)), - Failed<TransportInvalidError>()); -} - -#ifndef _WIN32 -TEST_F(HTTPDelimitedJSONTransportTest, ReadWithTimeout) { - ASSERT_THAT_EXPECTED( - transport->Read<JSONTestType>(std::chrono::milliseconds(1)), - Failed<TransportTimeoutError>()); -} - -TEST_F(JSONRPCTransportTest, ReadWithTimeout) { - ASSERT_THAT_EXPECTED( - transport->Read<JSONTestType>(std::chrono::milliseconds(1)), - Failed<TransportTimeoutError>()); -} - -// Windows CRT _read checks that the file descriptor is valid and calls a -// handler if not. This handler is normally a breakpoint, which looks like a -// crash when not handled by a debugger. 
-// https://learn.microsoft.com/en-us/%20cpp/c-runtime-library/reference/read?view=msvc-170 -TEST_F(HTTPDelimitedJSONTransportTest, ReadAfterClosed) { - input.CloseReadFileDescriptor(); - ASSERT_THAT_EXPECTED( - transport->Read<JSONTestType>(std::chrono::milliseconds(1)), - llvm::Failed()); + auto handle = transport->RegisterReadObject<JSONTestType>( + loop, [&](MainLoopBase &, llvm::Expected<JSONTestType>) {}); + ASSERT_THAT_EXPECTED(handle, FailedWithMessage("IO object is not valid.")); } -TEST_F(JSONRPCTransportTest, ReadAfterClosed) { - input.CloseReadFileDescriptor(); +TEST_F(JSONRPCTransportTest, NoDataTimeout) { ASSERT_THAT_EXPECTED( - transport->Read<JSONTestType>(std::chrono::milliseconds(1)), - llvm::Failed()); + RunOnce<JSONTestType>(/*timeout=*/std::chrono::milliseconds(10)), + FailedWithMessage("timeout")); } -#endif diff --git a/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp b/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp index b1cc21a5b0c37..bbda1e36cc6f1 100644 --- a/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp +++ b/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp @@ -9,16 +9,22 @@ #include "Plugins/Platform/MacOSX/PlatformRemoteMacOSX.h" #include "Plugins/Protocol/MCP/MCPError.h" #include "Plugins/Protocol/MCP/ProtocolServerMCP.h" -#include "TestingSupport/Host/SocketTestUtilities.h" #include "TestingSupport/SubsystemRAII.h" #include "lldb/Core/ProtocolServer.h" #include "lldb/Host/FileSystem.h" #include "lldb/Host/HostInfo.h" #include "lldb/Host/JSONTransport.h" +#include "lldb/Host/MainLoop.h" +#include "lldb/Host/MainLoopBase.h" #include "lldb/Host/Socket.h" +#include "lldb/Host/common/TCPSocket.h" #include "lldb/Protocol/MCP/Protocol.h" +#include "llvm/Support/Error.h" #include "llvm/Testing/Support/Error.h" #include "gtest/gtest.h" +#include <chrono> +#include <condition_variable> +#include <mutex> using namespace llvm; using namespace lldb; @@ -39,7 +45,7 @@ class TestProtocolServerMCP : public 
lldb_private::mcp::ProtocolServerMCP { class TestJSONTransport : public lldb_private::JSONRPCTransport { public: using JSONRPCTransport::JSONRPCTransport; - using JSONRPCTransport::ReadImpl; + using JSONRPCTransport::Parse; using JSONRPCTransport::WriteImpl; }; @@ -126,6 +132,7 @@ class ProtocolServerMCPTest : public ::testing::Test { lldb::IOObjectSP m_io_sp; std::unique_ptr<TestJSONTransport> m_transport_up; std::unique_ptr<TestProtocolServerMCP> m_server_up; + MainLoop loop; static constexpr llvm::StringLiteral k_localhost = "localhost"; @@ -133,11 +140,26 @@ class ProtocolServerMCPTest : public ::testing::Test { return m_transport_up->WriteImpl(llvm::formatv("{0}\n", message).str()); } - llvm::Expected<std::string> Read() { - return m_transport_up->ReadImpl(std::chrono::milliseconds(100)); + template <typename P> + void + RunOnce(const std::function<void(llvm::Expected<P>)> &callback, + std::chrono::milliseconds timeout = std::chrono::milliseconds(100)) { + auto handle = m_transport_up->RegisterReadObject<P>( + loop, [&](lldb_private::MainLoopBase &loop, llvm::Expected<P> message) { + callback(std::move(message)); + loop.RequestTermination(); + }); + loop.AddCallback( + [&](lldb_private::MainLoopBase &loop) { + loop.RequestTermination(); + FAIL() << "timeout waiting for read callback"; + }, + timeout); + ASSERT_THAT_EXPECTED(handle, llvm::Succeeded()); + ASSERT_THAT_ERROR(loop.Run().takeError(), llvm::Succeeded()); } - void SetUp() { + void SetUp() override { // Create a debugger. ArchSpec arch("arm64-apple-macosx-"); Platform::SetHostPlatform( @@ -169,7 +191,7 @@ class ProtocolServerMCPTest : public ::testing::Test { m_transport_up = std::make_unique<TestJSONTransport>(m_io_sp, m_io_sp); } - void TearDown() { + void TearDown() override { // Stop the server. 
ASSERT_THAT_ERROR(m_server_up->Stop(), llvm::Succeeded()); } @@ -184,17 +206,16 @@ TEST_F(ProtocolServerMCPTest, Intialization) { R"json( {"id":0,"jsonrpc":"2.0","result":{"capabilities":{"resources":{"listChanged":false,"subscribe":false},"tools":{"listChanged":true}},"protocolVersion":"2024-11-05","serverInfo":{"name":"lldb-mcp","version":"0.1.0"}}})json"; ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); + RunOnce<std::string>([&](llvm::Expected<std::string> response_str) { + ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); + llvm::Expected<json::Value> response_json = json::parse(*response_str); + ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); - llvm::Expected<std::string> response_str = Read(); - ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); + llvm::Expected<json::Value> expected_json = json::parse(response); + ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); - llvm::Expected<json::Value> response_json = json::parse(*response_str); - ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); - - llvm::Expected<json::Value> expected_json = json::parse(response); - ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); - - EXPECT_EQ(*response_json, *expected_json); + EXPECT_EQ(*response_json, *expected_json); + }); } TEST_F(ProtocolServerMCPTest, ToolsList) { @@ -204,17 +225,17 @@ TEST_F(ProtocolServerMCPTest, ToolsList) { R"json({"id":1,"jsonrpc":"2.0","result":{"tools":[{"description":"test tool","inputSchema":{"type":"object"},"name":"test"},{"description":"Run an lldb command.","inputSchema":{"properties":{"arguments":{"type":"string"},"debugger_id":{"type":"number"}},"required":["debugger_id"],"type":"object"},"name":"lldb_command"}]}})json"; ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); + RunOnce<std::string>([&](llvm::Expected<std::string> response_str) { + ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); - llvm::Expected<std::string> response_str = Read(); - ASSERT_THAT_EXPECTED(response_str, 
llvm::Succeeded()); + llvm::Expected<json::Value> response_json = json::parse(*response_str); + ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); - llvm::Expected<json::Value> response_json = json::parse(*response_str); - ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); + llvm::Expected<json::Value> expected_json = json::parse(response); + ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); - llvm::Expected<json::Value> expected_json = json::parse(response); - ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); - - EXPECT_EQ(*response_json, *expected_json); + EXPECT_EQ(*response_json, *expected_json); + }); } TEST_F(ProtocolServerMCPTest, ResourcesList) { @@ -224,17 +245,17 @@ TEST_F(ProtocolServerMCPTest, ResourcesList) { R"json({"id":2,"jsonrpc":"2.0","result":{"resources":[{"description":"description","mimeType":"application/json","name":"name","uri":"lldb://foo/bar"}]}})json"; ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); + RunOnce<std::string>([&](llvm::Expected<std::string> response_str) { + ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); - llvm::Expected<std::string> response_str = Read(); - ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); - - llvm::Expected<json::Value> response_json = json::parse(*response_str); - ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); + llvm::Expected<json::Value> response_json = json::parse(*response_str); + ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); - llvm::Expected<json::Value> expected_json = json::parse(response); - ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); + llvm::Expected<json::Value> expected_json = json::parse(response); + ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); - EXPECT_EQ(*response_json, *expected_json); + EXPECT_EQ(*response_json, *expected_json); + }); } TEST_F(ProtocolServerMCPTest, ToolsCall) { @@ -244,17 +265,17 @@ TEST_F(ProtocolServerMCPTest, ToolsCall) { 
R"json({"id":11,"jsonrpc":"2.0","result":{"content":[{"text":"foo","type":"text"}],"isError":false}})json"; ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); + RunOnce<std::string>([&](llvm::Expected<std::string> response_str) { + ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); - llvm::Expected<std::string> response_str = Read(); - ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); + llvm::Expected<json::Value> response_json = json::parse(*response_str); + ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); - llvm::Expected<json::Value> response_json = json::parse(*response_str); - ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); + llvm::Expected<json::Value> expected_json = json::parse(response); + ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); - llvm::Expected<json::Value> expected_json = json::parse(response); - ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); - - EXPECT_EQ(*response_json, *expected_json); + EXPECT_EQ(*response_json, *expected_json); + }); } TEST_F(ProtocolServerMCPTest, ToolsCallError) { @@ -266,17 +287,17 @@ TEST_F(ProtocolServerMCPTest, ToolsCallError) { R"json({"error":{"code":-32603,"message":"error"},"id":11,"jsonrpc":"2.0"})json"; ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); + RunOnce<std::string>([&](llvm::Expected<std::string> response_str) { + ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); - llvm::Expected<std::string> response_str = Read(); - ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); - - llvm::Expected<json::Value> response_json = json::parse(*response_str); - ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); + llvm::Expected<json::Value> response_json = json::parse(*response_str); + ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); - llvm::Expected<json::Value> expected_json = json::parse(response); - ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); + llvm::Expected<json::Value> expected_json = json::parse(response); + 
ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); - EXPECT_EQ(*response_json, *expected_json); + EXPECT_EQ(*response_json, *expected_json); + }); } TEST_F(ProtocolServerMCPTest, ToolsCallFail) { @@ -288,17 +309,17 @@ TEST_F(ProtocolServerMCPTest, ToolsCallFail) { R"json({"id":11,"jsonrpc":"2.0","result":{"content":[{"text":"failed","type":"text"}],"isError":true}})json"; ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); + RunOnce<std::string>([&](llvm::Expected<std::string> response_str) { + ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); - llvm::Expected<std::string> response_str = Read(); - ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); - - llvm::Expected<json::Value> response_json = json::parse(*response_str); - ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); + llvm::Expected<json::Value> response_json = json::parse(*response_str); + ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); - llvm::Expected<json::Value> expected_json = json::parse(response); - ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); + llvm::Expected<json::Value> expected_json = json::parse(response); + ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); - EXPECT_EQ(*response_json, *expected_json); + EXPECT_EQ(*response_json, *expected_json); + }); } TEST_F(ProtocolServerMCPTest, NotificationInitialized) { _______________________________________________ lldb-commits mailing list lldb-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/lldb-commits