https://github.com/ashgti created https://github.com/llvm/llvm-project/pull/155315
The `lldb_protocol::mcp::Binder` class is used to craft bindings between requests and notifications to specific handlers. This supports both incoming and outgoing handlers that bind these functions to a MessageHandler and generates encoding/decoding helpers for each call. For example, see the `lldb_protocol::mcp::Server` class that has been greatly simplified. >From 81643e70e88aa9cb91932071336ae817b1b2926d Mon Sep 17 00:00:00 2001 From: John Harrison <harj...@google.com> Date: Mon, 25 Aug 2025 14:40:08 -0700 Subject: [PATCH] [lldb] Creating a new Binder helper for JSONRPC transport. The `lldb_protocol::mcp::Binder` class is used to craft bindings between requests and notifications to specific handlers. This supports both incoming and outgoing handlers that bind these functions to a MessageHandler and generates encoding/decoding helpers for each call. For example, see the `lldb_protocol::mcp::Server` class that has been greatly simplified. --- lldb/include/lldb/Protocol/MCP/Binder.h | 351 ++++++++++++++++++ lldb/include/lldb/Protocol/MCP/Protocol.h | 173 ++++++++- lldb/include/lldb/Protocol/MCP/Resource.h | 2 +- lldb/include/lldb/Protocol/MCP/Server.h | 74 ++-- lldb/include/lldb/Protocol/MCP/Tool.h | 9 +- lldb/include/lldb/Protocol/MCP/Transport.h | 50 +++ .../Protocol/MCP/ProtocolServerMCP.cpp | 20 +- .../Plugins/Protocol/MCP/ProtocolServerMCP.h | 4 +- lldb/source/Plugins/Protocol/MCP/Resource.cpp | 10 +- lldb/source/Plugins/Protocol/MCP/Resource.h | 6 +- lldb/source/Plugins/Protocol/MCP/Tool.cpp | 26 +- lldb/source/Plugins/Protocol/MCP/Tool.h | 7 +- lldb/source/Protocol/MCP/Binder.cpp | 139 +++++++ lldb/source/Protocol/MCP/CMakeLists.txt | 3 + lldb/source/Protocol/MCP/Protocol.cpp | 159 +++++++- lldb/source/Protocol/MCP/Server.cpp | 255 +++---------- lldb/source/Protocol/MCP/Transport.cpp | 113 ++++++ lldb/unittests/Protocol/ProtocolMCPTest.cpp | 10 +- .../ProtocolServer/ProtocolMCPServerTest.cpp | 78 ++-- 19 files changed, 1160 insertions(+), 329 deletions(-) create mode 100644 lldb/include/lldb/Protocol/MCP/Binder.h create mode 100644 lldb/include/lldb/Protocol/MCP/Transport.h create mode 100644 lldb/source/Protocol/MCP/Binder.cpp create mode 100644 lldb/source/Protocol/MCP/Transport.cpp diff --git a/lldb/include/lldb/Protocol/MCP/Binder.h b/lldb/include/lldb/Protocol/MCP/Binder.h new file mode 100644 index 0000000000000..f9cebd940bfcb --- /dev/null +++ b/lldb/include/lldb/Protocol/MCP/Binder.h @@ -0,0 +1,351 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLDB_PROTOCOL_MCP_BINDER_H +#define LLDB_PROTOCOL_MCP_BINDER_H + +#include "lldb/Protocol/MCP/MCPError.h" +#include "lldb/Protocol/MCP/Protocol.h" +#include "lldb/Protocol/MCP/Transport.h" +#include "lldb/Utility/Status.h" +#include "llvm/ADT/FunctionExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/JSON.h" +#include <functional> +#include <future> +#include <memory> +#include <mutex> +#include <optional> + +namespace lldb_protocol::mcp { + +template <typename T> using Callback = llvm::unique_function<T>; + +template <typename T> +using Reply = llvm::unique_function<void(llvm::Expected<T>)>; +template <typename Params, typename Result> +using OutgoingRequest = + llvm::unique_function<void(const Params &, Reply<Result>)>; +template <typename Params> +using OutgoingNotification = llvm::unique_function<void(const Params &)>; + +template <typename Params, typename Result> +llvm::Expected<Result> AsyncInvoke(lldb_private::MainLoop &loop, + OutgoingRequest<Params, Result> &fn, + const Params ¶ms) { + std::promise<llvm::Expected<Result>> result_promise; + std::future<llvm::Expected<Result>> result_future = + result_promise.get_future(); + std::thread thr([&loop, &fn, params, + result_promise = std::move(result_promise)]() mutable { + fn(params, [&loop, &result_promise](llvm::Expected<Result> result) mutable { + result_promise.set_value(std::move(result)); + loop.AddPendingCallback( + [](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); }); + }); + if (llvm::Error error = loop.Run().takeError()) + result_promise.set_value(std::move(error)); + }); + thr.join(); + return result_future.get(); +} + +/// Binder collects a table of functions that handle calls. +/// +/// The wrapper takes care of parsing/serializing responses. +class Binder { +public: + explicit Binder(MCPTransport *transport) : m_handlers(transport) {} + + Binder(const Binder &) = delete; + Binder &operator=(const Binder &) = delete; + + /// Bind a handler on transport disconnect. + template <typename ThisT, typename... ExtraArgs> + void disconnected(void (ThisT::*handler)(MCPTransport *), ThisT *_this, + ExtraArgs... extra_args) { + m_handlers.m_disconnect_handler = + std::bind(handler, _this, std::placeholders::_1, + std::forward<ExtraArgs>(extra_args)...); + } + + /// Bind a handler on error when communicating with the transport. + template <typename ThisT, typename... ExtraArgs> + void error(void (ThisT::*handler)(MCPTransport *, llvm::Error), ThisT *_this, + ExtraArgs... extra_args) { + m_handlers.m_error_handler = + std::bind(handler, _this, std::placeholders::_1, std::placeholders::_2, + std::forward<ExtraArgs>(extra_args)...); + } + + /// Bind a handler for a request. + /// e.g. Bind.request("peek", this, &ThisModule::peek); + /// Handler should be e.g. Expected<PeekResult> peek(const PeekParams&); + /// PeekParams must be JSON parsable and PeekResult must be serializable. + template <typename Result, typename Params, typename ThisT, + typename... ExtraArgs> + void request(llvm::StringLiteral method, + llvm::Expected<Result> (ThisT::*fn)(const Params &, + ExtraArgs...), + ThisT *_this, ExtraArgs... extra_args) { + assert(m_handlers.m_request_handlers.find(method) == + m_handlers.m_request_handlers.end() && + "request already bound"); + std::function<llvm::Expected<Result>(const Params &)> handler = + std::bind(fn, _this, std::placeholders::_1, + std::forward<ExtraArgs>(extra_args)...); + m_handlers.m_request_handlers[method] = + [method, handler](const Request &req, + llvm::unique_function<void(const Response &)> reply) { + Params params; + llvm::json::Path::Root root(method); + if (!fromJSON(req.params, params, root)) { + reply(Response{0, Error{eErrorCodeInvalidParams, + "invalid params for " + method.str() + + ": " + llvm::toString(root.getError()), + std::nullopt}}); + return; + } + llvm::Expected<Result> result = handler(params); + if (llvm::Error error = result.takeError()) { + Error protocol_error; + llvm::handleAllErrors( + std::move(error), + [&](const MCPError &err) { + protocol_error = err.toProtocolError(); + }, + [&](const llvm::ErrorInfoBase &err) { + protocol_error.code = MCPError::kInternalError; + protocol_error.message = err.message(); + }); + reply(Response{0, protocol_error}); + return; + } + + reply(Response{0, *result}); + }; + } + + /// Bind a handler for an async request. + /// e.g. Bind.asyncRequest("peek", this, &ThisModule::peek); + /// Handler should be e.g. `void peek(const PeekParams&, + /// Reply<Expected<PeekResult>>);` PeekParams must be JSON parsable and + /// PeekResult must be serializable. + template <typename Result, typename Params, typename... ExtraArgs> + void asyncRequest( + llvm::StringLiteral method, + std::function<void(const Params &, ExtraArgs..., Reply<Result>)> fn, + ExtraArgs... extra_args) { + assert(m_handlers.m_request_handlers.find(method) == + m_handlers.m_request_handlers.end() && + "request already bound"); + std::function<void(const Params &, Reply<Result>)> handler = std::bind( + fn, std::placeholders::_1, std::forward<ExtraArgs>(extra_args)..., + std::placeholders::_2); + m_handlers.m_request_handlers[method] = + [method, handler](const Request &req, + Callback<void(const Response &)> reply) { + Params params; + llvm::json::Path::Root root(method); + if (!fromJSON(req.params, params, root)) { + reply(Response{0, Error{eErrorCodeInvalidParams, + "invalid params for " + method.str() + + ": " + llvm::toString(root.getError()), + std::nullopt}}); + return; + } + + handler(params, [reply = std::move(reply)]( + llvm::Expected<Result> result) mutable { + if (llvm::Error error = result.takeError()) { + Error protocol_error; + llvm::handleAllErrors( + std::move(error), + [&](const MCPError &err) { + protocol_error = err.toProtocolError(); + }, + [&](const llvm::ErrorInfoBase &err) { + protocol_error.code = MCPError::kInternalError; + protocol_error.message = err.message(); + }); + reply(Response{0, protocol_error}); + return; + } + + reply(Response{0, toJSON(*result)}); + }); + }; + } + template <typename Result, typename Params, typename ThisT, + typename... ExtraArgs> + void asyncRequest(llvm::StringLiteral method, + void (ThisT::*fn)(const Params &, ExtraArgs..., + Reply<Result>), + ThisT *_this, ExtraArgs... extra_args) { + assert(m_handlers.m_request_handlers.find(method) == + m_handlers.m_request_handlers.end() && + "request already bound"); + std::function<void(const Params &, Reply<Result>)> handler = std::bind( + fn, _this, std::placeholders::_1, + std::forward<ExtraArgs>(extra_args)..., std::placeholders::_2); + m_handlers.m_request_handlers[method] = + [method, handler](const Request &req, + Callback<void(const Response &)> reply) { + Params params; + llvm::json::Path::Root root; + if (!fromJSON(req.params, params, root)) { + reply(Response{0, Error{eErrorCodeInvalidParams, + "invalid params for " + method.str(), + std::nullopt}}); + return; + } + + handler(params, [reply = std::move(reply)]( + llvm::Expected<Result> result) mutable { + if (llvm::Error error = result.takeError()) { + Error protocol_error; + llvm::handleAllErrors( + std::move(error), + [&](const MCPError &err) { + protocol_error = err.toProtocolError(); + }, + [&](const llvm::ErrorInfoBase &err) { + protocol_error.code = MCPError::kInternalError; + protocol_error.message = err.message(); + }); + reply(Response{0, protocol_error}); + return; + } + + reply(Response{0, toJSON(*result)}); + }); + }; + } + + /// Bind a handler for a notification. + /// e.g. Bind.notification("peek", this, &ThisModule::peek); + /// Handler should be e.g. void peek(const PeekParams&); + /// PeekParams must be JSON parsable. + template <typename Params, typename ThisT, typename... ExtraArgs> + void notification(llvm::StringLiteral method, + void (ThisT::*fn)(const Params &, ExtraArgs...), + ThisT *_this, ExtraArgs... extra_args) { + std::function<void(const Params &)> handler = + std::bind(fn, _this, std::placeholders::_1, + std::forward<ExtraArgs>(extra_args)...); + m_handlers.m_notification_handlers[method] = + [handler](const Notification ¬e) { + Params params; + llvm::json::Path::Root root; + if (!fromJSON(note.params, params, root)) + return; // FIXME: log error? + + handler(params); + }; + } + template <typename Params> + void notification(llvm::StringLiteral method, + std::function<void(const Params &)> handler) { + assert(m_handlers.m_notification_handlers.find(method) == + m_handlers.m_notification_handlers.end() && + "notification already bound"); + m_handlers.m_notification_handlers[method] = + [handler = std::move(handler)](const Notification ¬e) { + Params params; + llvm::json::Path::Root root; + if (!fromJSON(note.params, params, root)) + return; // FIXME: log error? + + handler(params); + }; + } + + /// Bind a function object to be used for outgoing requests. + /// e.g. OutgoingRequest<Params, Result> Edit = Bind.outgoingRequest("edit"); + /// Params must be JSON-serializable, Result must be parsable. + template <typename Params, typename Result> + OutgoingRequest<Params, Result> outgoingRequest(llvm::StringLiteral method) { + return [this, method](const Params ¶ms, Reply<Result> reply) { + Request request; + request.method = method; + request.params = toJSON(params); + m_handlers.Send(request, [reply = std::move(reply)]( + const Response &resp) mutable { + if (const lldb_protocol::mcp::Error *err = + std::get_if<lldb_protocol::mcp::Error>(&resp.result)) { + reply(llvm::make_error<MCPError>(err->message, err->code)); + return; + } + Result result; + llvm::json::Path::Root root; + if (!fromJSON(std::get<llvm::json::Value>(resp.result), result, root)) { + reply(llvm::make_error<MCPError>("parsing response failed: " + + llvm::toString(root.getError()))); + return; + } + reply(result); + }); + }; + } + + /// Bind a function object to be used for outgoing notifications. + /// e.g. OutgoingNotification<LogParams> Log = Bind.outgoingMethod("log"); + /// LogParams must be JSON-serializable. + template <typename Params> + OutgoingNotification<Params> + outgoingNotification(llvm::StringLiteral method) { + return [this, method](const Params ¶ms) { + Notification note; + note.method = method; + note.params = toJSON(params); + m_handlers.Send(note); + }; + } + + operator MCPTransport::MessageHandler &() { return m_handlers; } + +private: + class RawHandler final : public MCPTransport::MessageHandler { + public: + explicit RawHandler(MCPTransport *transport); + + void Received(const Notification ¬e) override; + void Received(const Request &req) override; + void Received(const Response &resp) override; + void OnError(llvm::Error err) override; + void OnClosed() override; + + void Send(const Request &req, + Callback<void(const Response &)> response_handler); + void Send(const Notification ¬e); + void Send(const Response &resp); + + friend class Binder; + + private: + std::recursive_mutex m_mutex; + MCPTransport *m_transport; + int m_seq = 0; + std::map<Id, Callback<void(const Response &)>> m_pending_responses; + llvm::StringMap< + Callback<void(const Request &, Callback<void(const Response &)>)>> + m_request_handlers; + llvm::StringMap<Callback<void(const Notification &)>> + m_notification_handlers; + Callback<void(MCPTransport *)> m_disconnect_handler; + Callback<void(MCPTransport *, llvm::Error)> m_error_handler; + }; + + RawHandler m_handlers; +}; +using BinderUP = std::unique_ptr<Binder>; + +} // namespace lldb_protocol::mcp + +#endif diff --git a/lldb/include/lldb/Protocol/MCP/Protocol.h b/lldb/include/lldb/Protocol/MCP/Protocol.h index 49f9490221755..d21a5ef85ece6 100644 --- a/lldb/include/lldb/Protocol/MCP/Protocol.h +++ b/lldb/include/lldb/Protocol/MCP/Protocol.h @@ -14,10 +14,12 @@ #ifndef LLDB_PROTOCOL_MCP_PROTOCOL_H #define LLDB_PROTOCOL_MCP_PROTOCOL_H +#include "lldb/lldb-types.h" #include "llvm/Support/JSON.h" #include <optional> #include <string> #include <variant> +#include <vector> namespace lldb_protocol::mcp { @@ -43,6 +45,12 @@ llvm::json::Value toJSON(const Request &); bool fromJSON(const llvm::json::Value &, Request &, llvm::json::Path); bool operator==(const Request &, const Request &); +enum ErrorCode : signed { + eErrorCodeMethodNotFound = -32601, + eErrorCodeInvalidParams = -32602, + eErrorCodeInternalError = -32000, +}; + struct Error { /// The error type that occurred. int64_t code = 0; @@ -147,6 +155,14 @@ struct Resource { llvm::json::Value toJSON(const Resource &); bool fromJSON(const llvm::json::Value &, Resource &, llvm::json::Path); +/// The server’s response to a resources/list request from the client. +struct ResourcesListResult { + std::vector<Resource> resources; +}; +llvm::json::Value toJSON(const ResourcesListResult &); +bool fromJSON(const llvm::json::Value &, ResourcesListResult &, + llvm::json::Path); + /// The contents of a specific resource or sub-resource. struct ResourceContents { /// The URI of this resource. @@ -163,13 +179,23 @@ struct ResourceContents { llvm::json::Value toJSON(const ResourceContents &); bool fromJSON(const llvm::json::Value &, ResourceContents &, llvm::json::Path); +/// Sent from the client to the server, to read a specific resource URI. +struct ResourcesReadParams { + /// The URI of the resource to read. The URI can use any protocol; it is up to + /// the server how to interpret it. + std::string URI; +}; +llvm::json::Value toJSON(const ResourcesReadParams &); +bool fromJSON(const llvm::json::Value &, ResourcesReadParams &, + llvm::json::Path); + /// The server's response to a resources/read request from the client. -struct ResourceResult { +struct ResourcesReadResult { std::vector<ResourceContents> contents; }; - -llvm::json::Value toJSON(const ResourceResult &); -bool fromJSON(const llvm::json::Value &, ResourceResult &, llvm::json::Path); +llvm::json::Value toJSON(const ResourcesReadResult &); +bool fromJSON(const llvm::json::Value &, ResourcesReadResult &, + llvm::json::Path); /// Text provided to or from an LLM. struct TextContent { @@ -204,6 +230,145 @@ bool fromJSON(const llvm::json::Value &, ToolDefinition &, llvm::json::Path); using ToolArguments = std::variant<std::monostate, llvm::json::Value>; +/// Describes the name and version of an MCP implementation, with an optional +/// title for UI representation. +/// +/// see +/// https://modelcontextprotocol.io/specification/2025-06-18/schema#implementation +struct Implementation { + /// Intended for programmatic or logical use, but used as a display name in + /// past specs or fallback (if title isn’t present). + std::string name; + + /// Intended for UI and end-user contexts — optimized to be human-readable and + /// easily understood, even by those unfamiliar with domain-specific + /// terminology. + /// + /// If not provided, the name should be used for display (except for Tool, + /// where annotations.title should be given precedence over using name, if + /// present). + std::string title; + + std::string version; +}; +llvm::json::Value toJSON(const Implementation &); +bool fromJSON(const llvm::json::Value &, Implementation &, llvm::json::Path); + +/// Capabilities a client may support. Known capabilities are defined here, in +/// this schema, but this is not a closed set: any client can define its own, +/// additional capabilities. +struct ClientCapabilities {}; +llvm::json::Value toJSON(const ClientCapabilities &); +bool fromJSON(const llvm::json::Value &, ClientCapabilities &, + llvm::json::Path); + +/// Capabilities that a server may support. Known capabilities are defined here, +/// in this schema, but this is not a closed set: any server can define its own, +/// additional capabilities. +struct ServerCapabilities { + bool supportsToolsList = false; + bool supportsResourcesList = false; + bool supportsResourcesSubscribe = false; + + /// Utilities. + bool supportsCompletions = false; + bool supportsLogging = false; +}; +llvm::json::Value toJSON(const ServerCapabilities &); +bool fromJSON(const llvm::json::Value &, ServerCapabilities &, + llvm::json::Path); + +/// Initialization + +/// This request is sent from the client to the server when it first connects, +/// asking it to begin initialization. +/// +/// @category initialize +struct InitializeParams { + /// The latest version of the Model Context Protocol that the client supports. + /// The client MAY decide to support older versions as well. + std::string protocolVersion; + + ClientCapabilities capabilities; + + Implementation clientInfo; +}; +llvm::json::Value toJSON(const InitializeParams &); +bool fromJSON(const llvm::json::Value &, InitializeParams &, llvm::json::Path); + +/// After receiving an initialize request from the client, the server sends this +/// response. +/// +/// @category initialize +struct InitializeResult { + /// The version of the Model Context Protocol that the server wants to use. + /// This may not match the version that the client requested. If the client + /// cannot support this version, it MUST disconnect. + std::string protocolVersion; + + ServerCapabilities capabilities; + Implementation serverInfo; + + /// Instructions describing how to use the server and its features. + /// + /// This can be used by clients to improve the LLM's understanding of + /// available tools, resources, etc. It can be thought of like a "hint" to the + /// model. For example, this information MAY be added to the system prompt. + std::string instructions; +}; +llvm::json::Value toJSON(const InitializeResult &); +bool fromJSON(const llvm::json::Value &, InitializeResult &, llvm::json::Path); + +/// Special case parameter. +using Void = std::monostate; +llvm::json::Value toJSON(const Void &); +bool fromJSON(const llvm::json::Value &, Void &, llvm::json::Path); + +/// The server's response to a `tools/list` request from the client. +struct ToolsListResult { + std::vector<ToolDefinition> tools; +}; +llvm::json::Value toJSON(const ToolsListResult &); +bool fromJSON(const llvm::json::Value &, ToolsListResult &, llvm::json::Path); + +// FIXME: Support other content types as needed. +using ContentBlock = TextContent; + +/// Used by the client to invoke a tool provided by the server. +struct ToolsCallParams { + std::string name; + std::optional<llvm::json::Value> arguments; +}; +llvm::json::Value toJSON(const ToolsCallParams &); +bool fromJSON(const llvm::json::Value &, ToolsCallParams &, llvm::json::Path); + +/// The server’s response to a tool call. +struct ToolsCallResult { + /// A list of content objects that represent the unstructured result of the + /// tool call. + std::vector<ContentBlock> content; + + /// Whether the tool call ended in an error. + /// + /// If not set, this is assumed to be false (the call was successful). + /// + /// Any errors that originate from the tool SHOULD be reported inside the + /// result object, with `isError` set to true, not as an MCP protocol-level + /// error response. Otherwise, the LLM would not be able to see that an error + /// occurred and self-correct. + /// + /// However, any errors in finding the tool, an error indicating that the + /// server does not support tool calls, or any other exceptional conditions, + /// should be reported as an MCP error response. + bool isError = false; + + /// An optional JSON object that represents the structured result of the tool + /// call. + std::optional<llvm::json::Value> structuredContent; +}; +llvm::json::Value toJSON(const ToolsCallResult &); +bool fromJSON(const llvm::json::Value &, ToolsCallResult &, llvm::json::Path); + } // namespace lldb_protocol::mcp #endif diff --git a/lldb/include/lldb/Protocol/MCP/Resource.h b/lldb/include/lldb/Protocol/MCP/Resource.h index 4835d340cd4c6..8a3e3ca725eb5 100644 --- a/lldb/include/lldb/Protocol/MCP/Resource.h +++ b/lldb/include/lldb/Protocol/MCP/Resource.h @@ -20,7 +20,7 @@ class ResourceProvider { virtual ~ResourceProvider() = default; virtual std::vector<lldb_protocol::mcp::Resource> GetResources() const = 0; - virtual llvm::Expected<lldb_protocol::mcp::ResourceResult> + virtual llvm::Expected<lldb_protocol::mcp::ResourcesReadResult> ReadResource(llvm::StringRef uri) const = 0; }; diff --git a/lldb/include/lldb/Protocol/MCP/Server.h b/lldb/include/lldb/Protocol/MCP/Server.h index 382f9a4731dd4..d749f8d493153 100644 --- a/lldb/include/lldb/Protocol/MCP/Server.h +++ b/lldb/include/lldb/Protocol/MCP/Server.h @@ -9,82 +9,52 @@ #ifndef LLDB_PROTOCOL_MCP_SERVER_H #define LLDB_PROTOCOL_MCP_SERVER_H -#include "lldb/Host/JSONTransport.h" #include "lldb/Host/MainLoop.h" +#include "lldb/Protocol/MCP/Binder.h" #include "lldb/Protocol/MCP/Protocol.h" #include "lldb/Protocol/MCP/Resource.h" #include "lldb/Protocol/MCP/Tool.h" +#include "lldb/Protocol/MCP/Transport.h" #include "llvm/ADT/StringMap.h" #include "llvm/Support/Error.h" +#include <memory> #include <mutex> -namespace lldb_protocol::mcp { - -class MCPTransport final - : public lldb_private::JSONRPCTransport<Request, Response, Notification> { -public: - using LogCallback = std::function<void(llvm::StringRef message)>; - - MCPTransport(lldb::IOObjectSP in, lldb::IOObjectSP out, - std::string client_name, LogCallback log_callback = {}) - : JSONRPCTransport(in, out), m_client_name(std::move(client_name)), - m_log_callback(log_callback) {} - virtual ~MCPTransport() = default; +namespace lldb_private::mcp { +class ProtocolServerMCP; +} // namespace lldb_private::mcp - void Log(llvm::StringRef message) override { - if (m_log_callback) - m_log_callback(llvm::formatv("{0}: {1}", m_client_name, message).str()); - } +namespace lldb_protocol::mcp { -private: - std::string m_client_name; - LogCallback m_log_callback; -}; +class Server { + friend class lldb_private::mcp::ProtocolServerMCP; + friend class lldb_private::mcp::ProtocolServerMCP; -class Server : public MCPTransport::MessageHandler { public: Server(std::string name, std::string version, std::unique_ptr<MCPTransport> transport_up, lldb_private::MainLoop &loop); ~Server() = default; - using NotificationHandler = std::function<void(const Notification &)>; - void AddTool(std::unique_ptr<Tool> tool); void AddResourceProvider(std::unique_ptr<ResourceProvider> resource_provider); - void AddNotificationHandler(llvm::StringRef method, - NotificationHandler handler); llvm::Error Run(); -protected: - Capabilities GetCapabilities(); - - using RequestHandler = - std::function<llvm::Expected<Response>(const Request &)>; - - void AddRequestHandlers(); - - void AddRequestHandler(llvm::StringRef method, RequestHandler handler); + Binder &GetBinder() { return m_binder; }; - llvm::Expected<std::optional<Message>> HandleData(llvm::StringRef data); - - llvm::Expected<Response> Handle(Request request); - void Handle(Notification notification); - - llvm::Expected<Response> InitializeHandler(const Request &); +protected: + ServerCapabilities GetCapabilities(); - llvm::Expected<Response> ToolsListHandler(const Request &); - llvm::Expected<Response> ToolsCallHandler(const Request &); + llvm::Expected<InitializeResult> + InitializeHandler(const InitializeParams &request); - llvm::Expected<Response> ResourcesListHandler(const Request &); - llvm::Expected<Response> ResourcesReadHandler(const Request &); + llvm::Expected<ToolsListResult> ToolsListHandler(const Void &); + llvm::Expected<ToolsCallResult> ToolsCallHandler(const ToolsCallParams &); - void Received(const Request &) override; - void Received(const Response &) override; - void Received(const Notification &) override; - void OnError(llvm::Error) override; - void OnClosed() override; + llvm::Expected<ResourcesListResult> ResourcesListHandler(const Void &); + llvm::Expected<ResourcesReadResult> + ResourcesReadHandler(const ResourcesReadParams &); void TerminateLoop(); @@ -99,9 +69,7 @@ class Server : public MCPTransport::MessageHandler { llvm::StringMap<std::unique_ptr<Tool>> m_tools; std::vector<std::unique_ptr<ResourceProvider>> m_resource_providers; - - llvm::StringMap<RequestHandler> m_request_handlers; - llvm::StringMap<NotificationHandler> m_notification_handlers; + Binder m_binder; }; } // namespace lldb_protocol::mcp diff --git a/lldb/include/lldb/Protocol/MCP/Tool.h b/lldb/include/lldb/Protocol/MCP/Tool.h index 96669d1357166..26cbc943f0704 100644 --- a/lldb/include/lldb/Protocol/MCP/Tool.h +++ b/lldb/include/lldb/Protocol/MCP/Tool.h @@ -10,6 +10,8 @@ #define LLDB_PROTOCOL_MCP_TOOL_H #include "lldb/Protocol/MCP/Protocol.h" +#include "llvm/ADT/FunctionExtras.h" +#include "llvm/Support/Error.h" #include "llvm/Support/JSON.h" #include <string> @@ -20,8 +22,11 @@ class Tool { Tool(std::string name, std::string description); virtual ~Tool() = default; - virtual llvm::Expected<lldb_protocol::mcp::TextResult> - Call(const lldb_protocol::mcp::ToolArguments &args) = 0; + using Reply = llvm::unique_function<void( + llvm::Expected<lldb_protocol::mcp::ToolsCallResult>)>; + + virtual void Call(const lldb_protocol::mcp::ToolArguments &args, + Reply reply) = 0; virtual std::optional<llvm::json::Value> GetSchema() const { return llvm::json::Object{{"type", "object"}}; diff --git a/lldb/include/lldb/Protocol/MCP/Transport.h b/lldb/include/lldb/Protocol/MCP/Transport.h new file mode 100644 index 0000000000000..efbddc6d31d17 --- /dev/null +++ b/lldb/include/lldb/Protocol/MCP/Transport.h @@ -0,0 +1,50 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLDB_PROTOCOL_MCP_TRANSPORT_H +#define LLDB_PROTOCOL_MCP_TRANSPORT_H + +#include "lldb/Host/JSONTransport.h" +#include "lldb/Host/Socket.h" +#include "lldb/Protocol/MCP/Protocol.h" +#include "lldb/lldb-forward.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/raw_ostream.h" +#include <memory> + +namespace lldb_protocol::mcp { + +using MCPTransport = lldb_private::Transport<Request, Response, Notification>; +using MCPTransportUP = std::unique_ptr<MCPTransport>; + +llvm::StringRef CommunicationSocketPath(); +llvm::Expected<lldb::IOObjectSP> Connect(); + +class Transport final + : public lldb_private::JSONRPCTransport<Request, Response, Notification> { +public: + using LogCallback = std::function<void(llvm::StringRef message)>; + + Transport(lldb::IOObjectSP input, lldb::IOObjectSP output, + std::string client_name = "", LogCallback log_callback = {}); + + void Log(llvm::StringRef message) override; + + static llvm::Expected<MCPTransportUP> + Connect(llvm::raw_ostream *logger = nullptr); + +private: + std::string m_client_name; + LogCallback m_log_callback; +}; +using TransportUP = std::unique_ptr<Transport>; + +} // namespace lldb_protocol::mcp + +#endif diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp index 57132534cf680..15558b4e7c914 100644 --- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp @@ -10,14 +10,10 @@ #include "Resource.h" #include "Tool.h" #include "lldb/Core/PluginManager.h" -#include "lldb/Protocol/MCP/MCPError.h" -#include "lldb/Protocol/MCP/Tool.h" #include "lldb/Utility/LLDBLog.h" #include "lldb/Utility/Log.h" -#include "llvm/ADT/StringExtras.h" #include "llvm/Support/Threading.h" #include <thread> -#include <variant> using namespace lldb_private; using namespace lldb_private::mcp; @@ -50,12 +46,14 @@ llvm::StringRef ProtocolServerMCP::GetPluginDescriptionStatic() { return "MCP Server."; } -void ProtocolServerMCP::Extend(lldb_protocol::mcp::Server &server) const { - server.AddNotificationHandler("notifications/initialized", - [](const lldb_protocol::mcp::Notification &) { - LLDB_LOG(GetLog(LLDBLog::Host), - "MCP initialization complete"); - }); +void ProtocolServerMCP::OnInitialized( + const lldb_protocol::mcp::Notification &) { + LLDB_LOG(GetLog(LLDBLog::Host), "MCP initialization complete"); +} + +void ProtocolServerMCP::Extend(lldb_protocol::mcp::Server &server) { + server.m_binder.notification("notifications/initialized", + &ProtocolServerMCP::OnInitialized, this); server.AddTool( std::make_unique<CommandTool>("lldb_command", "Run an lldb command.")); server.AddResourceProvider(std::make_unique<DebuggerResourceProvider>()); @@ -67,7 +65,7 @@ void ProtocolServerMCP::AcceptCallback(std::unique_ptr<Socket> socket) { LLDB_LOG(log, "New MCP client connected: {0}", client_name); lldb::IOObjectSP io_sp = std::move(socket); - auto transport_up = std::make_unique<lldb_protocol::mcp::MCPTransport>( + auto transport_up = std::make_unique<lldb_protocol::mcp::Transport>( io_sp, io_sp, std::move(client_name), [&](llvm::StringRef message) { LLDB_LOG(GetLog(LLDBLog::Host), "{0}", message); }); diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h index fc650ffe0dfa7..d35f7f678b2c4 100644 --- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h @@ -41,7 +41,9 @@ class ProtocolServerMCP : public ProtocolServer { protected: // This adds tools and resource providers that // are specific to this server. Overridable by the unit tests. - virtual void Extend(lldb_protocol::mcp::Server &server) const; + virtual void Extend(lldb_protocol::mcp::Server &server); + + void OnInitialized(const lldb_protocol::mcp::Notification &); private: void AcceptCallback(std::unique_ptr<Socket> socket); diff --git a/lldb/source/Plugins/Protocol/MCP/Resource.cpp b/lldb/source/Plugins/Protocol/MCP/Resource.cpp index e94d2cdd65e07..b5f0a6569654b 100644 --- a/lldb/source/Plugins/Protocol/MCP/Resource.cpp +++ b/lldb/source/Plugins/Protocol/MCP/Resource.cpp @@ -124,7 +124,7 @@ DebuggerResourceProvider::GetResources() const { return resources; } -llvm::Expected<lldb_protocol::mcp::ResourceResult> +llvm::Expected<lldb_protocol::mcp::ResourcesReadResult> DebuggerResourceProvider::ReadResource(llvm::StringRef uri) const { auto [protocol, path] = uri.split("://"); @@ -161,7 +161,7 @@ DebuggerResourceProvider::ReadResource(llvm::StringRef uri) const { return ReadDebuggerResource(uri, debugger_idx); } -llvm::Expected<lldb_protocol::mcp::ResourceResult> +llvm::Expected<lldb_protocol::mcp::ResourcesReadResult> DebuggerResourceProvider::ReadDebuggerResource(llvm::StringRef uri, lldb::user_id_t debugger_id) { lldb::DebuggerSP debugger_sp = Debugger::FindDebuggerWithID(debugger_id); @@ -178,12 +178,12 @@ DebuggerResourceProvider::ReadDebuggerResource(llvm::StringRef uri, contents.mimeType = kMimeTypeJSON; contents.text = llvm::formatv("{0}", toJSON(debugger_resource)); - lldb_protocol::mcp::ResourceResult result; + lldb_protocol::mcp::ResourcesReadResult result; result.contents.push_back(contents); return result; } -llvm::Expected<lldb_protocol::mcp::ResourceResult> +llvm::Expected<lldb_protocol::mcp::ResourcesReadResult> DebuggerResourceProvider::ReadTargetResource(llvm::StringRef uri, lldb::user_id_t debugger_id, size_t target_idx) { @@ -214,7 +214,7 @@ DebuggerResourceProvider::ReadTargetResource(llvm::StringRef uri, contents.mimeType = kMimeTypeJSON; contents.text = llvm::formatv("{0}", toJSON(target_resource)); - lldb_protocol::mcp::ResourceResult result; + lldb_protocol::mcp::ResourcesReadResult result; result.contents.push_back(contents); return result; } diff --git a/lldb/source/Plugins/Protocol/MCP/Resource.h b/lldb/source/Plugins/Protocol/MCP/Resource.h index e2382a74f796b..0810f1fb0c4f4 100644 --- a/lldb/source/Plugins/Protocol/MCP/Resource.h +++ b/lldb/source/Plugins/Protocol/MCP/Resource.h @@ -23,7 +23,7 @@ class DebuggerResourceProvider : public lldb_protocol::mcp::ResourceProvider { virtual std::vector<lldb_protocol::mcp::Resource> GetResources() const override; - virtual llvm::Expected<lldb_protocol::mcp::ResourceResult> + virtual llvm::Expected<lldb_protocol::mcp::ResourcesReadResult> ReadResource(llvm::StringRef uri) const override; private: @@ -31,9 +31,9 @@ class DebuggerResourceProvider : public lldb_protocol::mcp::ResourceProvider { static lldb_protocol::mcp::Resource GetTargetResource(size_t target_idx, Target &target); - static llvm::Expected<lldb_protocol::mcp::ResourceResult> + static llvm::Expected<lldb_protocol::mcp::ResourcesReadResult> ReadDebuggerResource(llvm::StringRef uri, lldb::user_id_t debugger_id); - static llvm::Expected<lldb_protocol::mcp::ResourceResult> + static llvm::Expected<lldb_protocol::mcp::ResourcesReadResult> ReadTargetResource(llvm::StringRef uri, lldb::user_id_t debugger_id, size_t target_idx); }; diff --git a/lldb/source/Plugins/Protocol/MCP/Tool.cpp b/lldb/source/Plugins/Protocol/MCP/Tool.cpp index 143470702a6fd..dabf100874b62 100644 --- a/lldb/source/Plugins/Protocol/MCP/Tool.cpp +++ b/lldb/source/Plugins/Protocol/MCP/Tool.cpp @@ -14,6 +14,7 @@ using namespace lldb_private; using namespace lldb_protocol; using namespace lldb_private::mcp; +using namespace lldb_protocol::mcp; using namespace llvm; namespace { @@ -29,10 +30,10 @@ bool fromJSON(const llvm::json::Value &V, CommandToolArguments &A, O.mapOptional("arguments", A.arguments); } -/// Helper function to create a TextResult from a string output. -static lldb_protocol::mcp::TextResult createTextResult(std::string output, - bool is_error = false) { - lldb_protocol::mcp::TextResult text_result; +/// Helper function to create a ToolsCallResult from a string output. +static lldb_protocol::mcp::ToolsCallResult +createTextResult(std::string output, bool is_error = false) { + lldb_protocol::mcp::ToolsCallResult text_result; text_result.content.emplace_back( lldb_protocol::mcp::TextContent{{std::move(output)}}); text_result.isError = is_error; @@ -41,22 +42,23 @@ static lldb_protocol::mcp::TextResult createTextResult(std::string output, } // namespace -llvm::Expected<lldb_protocol::mcp::TextResult> -CommandTool::Call(const lldb_protocol::mcp::ToolArguments &args) { +namespace lldb_private::mcp { + +void CommandTool::Call(const ToolArguments &args, Reply reply) { if (!std::holds_alternative<json::Value>(args)) - return createStringError("CommandTool requires arguments"); + return reply(createStringError("CommandTool requires arguments")); json::Path::Root root; CommandToolArguments arguments; if (!fromJSON(std::get<json::Value>(args), arguments, root)) - return root.getError(); + return reply(root.getError()); lldb::DebuggerSP debugger_sp = Debugger::FindDebuggerWithID(arguments.debugger_id); if (!debugger_sp) - return createStringError( - llvm::formatv("no debugger with id {0}", arguments.debugger_id)); + return reply(createStringError( + llvm::formatv("no debugger with id {0}", arguments.debugger_id))); // FIXME: Disallow certain commands and their aliases. CommandReturnObject result(/*colors=*/false); @@ -75,7 +77,7 @@ CommandTool::Call(const lldb_protocol::mcp::ToolArguments &args) { output += err_str; } - return createTextResult(output, !result.Succeeded()); + reply(createTextResult(output, !result.Succeeded())); } std::optional<llvm::json::Value> CommandTool::GetSchema() const { @@ -89,3 +91,5 @@ std::optional<llvm::json::Value> CommandTool::GetSchema() const { {"required", std::move(required)}}; return schema; } + +} // namespace lldb_private::mcp diff --git a/lldb/source/Plugins/Protocol/MCP/Tool.h b/lldb/source/Plugins/Protocol/MCP/Tool.h index b7b1756eb38d7..4fc5884174e01 100644 --- a/lldb/source/Plugins/Protocol/MCP/Tool.h +++ b/lldb/source/Plugins/Protocol/MCP/Tool.h @@ -22,10 +22,11 @@ class CommandTool : public lldb_protocol::mcp::Tool { using lldb_protocol::mcp::Tool::Tool; ~CommandTool() = default; - virtual llvm::Expected<lldb_protocol::mcp::TextResult> - Call(const lldb_protocol::mcp::ToolArguments &args) override; + void Call(const lldb_protocol::mcp::ToolArguments &, + llvm::unique_function<void( + llvm::Expected<lldb_protocol::mcp::ToolsCallResult>)>) override; - virtual std::optional<llvm::json::Value> GetSchema() const override; + std::optional<llvm::json::Value> GetSchema() const override; }; } // namespace lldb_private::mcp diff --git a/lldb/source/Protocol/MCP/Binder.cpp b/lldb/source/Protocol/MCP/Binder.cpp new file mode 100644 index 0000000000000..90ae39ba0e3f0 --- /dev/null +++ b/lldb/source/Protocol/MCP/Binder.cpp @@ -0,0 +1,139 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "lldb/Protocol/MCP/Binder.h" +#include "lldb/Protocol/MCP/Protocol.h" +#include "lldb/Protocol/MCP/Transport.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/FormatVariadic.h" +#include <atomic> +#include <cassert> +#include <mutex> + +using namespace llvm; + +namespace lldb_protocol::mcp { + +/// Function object to reply to a call. +/// Each instance must be called exactly once, otherwise: +/// - the bug is logged, and (in debug mode) an assert will fire +/// - if there was no reply, an error reply is sent +/// - if there were multiple replies, only the first is sent +class ReplyOnce { + std::atomic<bool> replied = {false}; + const Id id; + MCPTransport *transport; // Null when moved-from. + MCPTransport::MessageHandler *handler; // Null when moved-from. + +public: + ReplyOnce(const Id id, MCPTransport *transport, + MCPTransport::MessageHandler *handler) + : id(id), transport(transport), handler(handler) { + assert(handler); + } + ReplyOnce(ReplyOnce &&other) + : replied(other.replied.load()), id(other.id), transport(other.transport), + handler(other.handler) { + other.transport = nullptr; + other.handler = nullptr; + } + ReplyOnce &operator=(ReplyOnce &&) = delete; + ReplyOnce(const ReplyOnce &) = delete; + ReplyOnce &operator=(const ReplyOnce &) = delete; + + ~ReplyOnce() { + if (transport && handler && !replied) { + assert(false && "must reply to all calls!"); + (*this)(Response{id, Error{MCPError::kInternalError, "failed to reply", + std::nullopt}}); + } + } + + void operator()(const Response &resp) { + assert(transport && handler && "moved-from!"); + if (replied.exchange(true)) { + assert(false && "must reply to each call only once!"); + return; + } + + if (llvm::Error error = transport->Send(Response{id, resp.result})) + handler->OnError(std::move(error)); + } +}; + +Binder::RawHandler::RawHandler(MCPTransport *transport) + : m_transport(transport) {} + +void Binder::RawHandler::Received(const Notification ¬e) { + std::scoped_lock<std::recursive_mutex> guard(m_mutex); + auto it = m_notification_handlers.find(note.method); + if (it == m_notification_handlers.end()) { + OnError(llvm::createStringError( + formatv("no handled for notification {0}", toJSON(note)))); + return; + } + it->second(note); +} + +void Binder::RawHandler::Received(const Request &req) { + ReplyOnce reply(req.id, m_transport, this); + + std::scoped_lock<std::recursive_mutex> guard(m_mutex); + auto it = m_request_handlers.find(req.method); + if (it == m_request_handlers.end()) { + reply({req.id, + Error{eErrorCodeMethodNotFound, "method not found", std::nullopt}}); + return; + } + + it->second(req, std::move(reply)); +} + +void Binder::RawHandler::Received(const Response &resp) { + std::scoped_lock<std::recursive_mutex> guard(m_mutex); + auto it = m_pending_responses.find(resp.id); + if (it == m_pending_responses.end()) { + OnError(llvm::createStringError( + formatv("no pending request for {0}", toJSON(resp)))); + return; + } + + it->second(resp); + m_pending_responses.erase(it); +} + +void Binder::RawHandler::OnError(llvm::Error err) { + std::scoped_lock<std::recursive_mutex> guard(m_mutex); + if (m_error_handler) + m_error_handler(m_transport, std::move(err)); +} + +void Binder::RawHandler::OnClosed() { + std::scoped_lock<std::recursive_mutex> guard(m_mutex); + if (m_disconnect_handler) + m_disconnect_handler(m_transport); +} + +void Binder::RawHandler::Send( + const Request &req, + llvm::unique_function<void(const Response &)> response_handler) { + std::lock_guard<std::recursive_mutex> guard(m_mutex); + Id id = ++m_seq; + if (llvm::Error err = m_transport->Send(Request{id, req.method, req.params})) + return OnError(std::move(err)); + m_pending_responses[id] = std::move(response_handler); +} + +void Binder::RawHandler::Send(const Notification ¬e) { + std::lock_guard<std::recursive_mutex> guard(m_mutex); + if (llvm::Error err = m_transport->Send(note)) + return OnError(std::move(err)); +} + +} // namespace lldb_protocol::mcp diff --git a/lldb/source/Protocol/MCP/CMakeLists.txt b/lldb/source/Protocol/MCP/CMakeLists.txt index a73e7e6a7cab1..e6e8200833ffd 100644 --- a/lldb/source/Protocol/MCP/CMakeLists.txt +++ b/lldb/source/Protocol/MCP/CMakeLists.txt @@ -1,12 +1,15 @@ add_lldb_library(lldbProtocolMCP NO_PLUGIN_DEPENDENCIES + Binder.cpp MCPError.cpp Protocol.cpp Server.cpp Tool.cpp + Transport.cpp LINK_COMPONENTS Support LINK_LIBS lldbUtility + lldbHost ) diff --git a/lldb/source/Protocol/MCP/Protocol.cpp b/lldb/source/Protocol/MCP/Protocol.cpp index 65ddfaee70160..8a976bb797d32 100644 --- a/lldb/source/Protocol/MCP/Protocol.cpp +++ b/lldb/source/Protocol/MCP/Protocol.cpp @@ -228,11 +228,11 @@ bool fromJSON(const llvm::json::Value &V, ResourceContents &RC, O.mapOptional("mimeType", RC.mimeType); } -llvm::json::Value toJSON(const ResourceResult &RR) { +llvm::json::Value toJSON(const ResourcesReadResult &RR) { return llvm::json::Object{{"contents", RR.contents}}; } -bool fromJSON(const llvm::json::Value &V, ResourceResult &RR, +bool fromJSON(const llvm::json::Value &V, ResourcesReadResult &RR, llvm::json::Path P) { llvm::json::ObjectMapper O(V, P); return O && O.map("contents", RR.contents); @@ -325,4 +325,159 @@ bool fromJSON(const llvm::json::Value &V, Message &M, llvm::json::Path P) { return false; } +json::Value toJSON(const Implementation &I) { + json::Object result{{"name", I.name}, {"version", I.version}}; + + if (!I.title.empty()) + result.insert({"title", I.title}); + + return result; +} + +bool fromJSON(const json::Value &V, Implementation &I, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("name", I.name) && O.mapOptional("title", I.title) && + O.mapOptional("version", I.version); +} + +json::Value toJSON(const ClientCapabilities &C) { return json::Object{}; } + +bool fromJSON(const json::Value &, ClientCapabilities &, json::Path) { + return true; +} + +json::Value toJSON(const ServerCapabilities &C) { + json::Object result{}; + + if (C.supportsToolsList) + result.insert({"tools", json::Object{{"listChanged", true}}}); + + if (C.supportsResourcesList || C.supportsResourcesSubscribe) { + json::Object resources; + if (C.supportsResourcesList) + resources.insert({"listChanged", true}); + if (C.supportsResourcesSubscribe) + resources.insert({"subscribe", true}); + result.insert({"resources", std::move(resources)}); + } + + if (C.supportsCompletions) + result.insert({"completions", json::Object{}}); + + if (C.supportsLogging) + result.insert({"logging", json::Object{}}); + + return result; +} + +bool fromJSON(const json::Value &V, ServerCapabilities &C, json::Path P) { + const json::Object *O = V.getAsObject(); + if (!O) { + P.report("expected object"); + return false; + } + + if (O->find("tools") != O->end()) + C.supportsToolsList = true; + + return true; +} + +json::Value toJSON(const InitializeParams &P) { + return json::Object{ + {"protocolVersion", P.protocolVersion}, + {"capabilities", P.capabilities}, + {"clientInfo", P.clientInfo}, + }; +} + +bool fromJSON(const json::Value &V, InitializeParams &I, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("protocolVersion", I.protocolVersion) && + O.map("capabilities", I.capabilities) && + O.map("clientInfo", I.clientInfo); +} + +json::Value toJSON(const InitializeResult &R) { + json::Object result{{"protocolVersion", R.protocolVersion}, + {"capabilities", R.capabilities}, + {"serverInfo", R.serverInfo}}; + + if (!R.instructions.empty()) + result.insert({"instructions", R.instructions}); + + return result; +} + +bool fromJSON(const json::Value &V, InitializeResult &R, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("protocolVersion", R.protocolVersion) && + O.map("capabilities", R.capabilities) && + O.map("serverInfo", R.serverInfo) && + O.mapOptional("instructions", R.instructions); +} + +json::Value toJSON(const ToolsListResult &R) { + return json::Object{{"tools", R.tools}}; +} + +bool fromJSON(const json::Value &V, ToolsListResult &R, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("tools", R.tools); +} + +json::Value toJSON(const ToolsCallResult &R) { + json::Object result{{"content", R.content}}; + + if (R.isError) + result.insert({"isError", R.isError}); + if (R.structuredContent) + result.insert({"structuredContent", *R.structuredContent}); + + return result; +} + +bool fromJSON(const json::Value &V, ToolsCallResult &R, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("content", R.content) && + O.mapOptional("isError", R.isError) && + mapRaw(V, "structuredContent", R.structuredContent, P); +} + +json::Value toJSON(const ToolsCallParams &R) { + json::Object result{{"name", R.name}}; + + if (R.arguments) + result.insert({"arguments", *R.arguments}); + + return result; +} + +bool fromJSON(const json::Value &V, ToolsCallParams &R, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("name", R.name) && mapRaw(V, "arguments", R.arguments, P); +} + +json::Value toJSON(const ResourcesReadParams &R) { + return json::Object{{"uri", R.URI}}; +} + +bool fromJSON(const json::Value &V, ResourcesReadParams &R, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("uri", R.URI); +} + +json::Value toJSON(const ResourcesListResult &R) { + return json::Object{{"resources", R.resources}}; +} + +bool fromJSON(const json::Value &V, ResourcesListResult &R, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("resources", R.resources); +} + +json::Value toJSON(const Void &R) { return json::Object{}; } + +bool fromJSON(const json::Value &V, Void &R, json::Path P) { return true; } + } // namespace lldb_protocol::mcp diff --git a/lldb/source/Protocol/MCP/Server.cpp b/lldb/source/Protocol/MCP/Server.cpp index 3713e8e46c5d6..a612967d5fa51 100644 --- a/lldb/source/Protocol/MCP/Server.cpp +++ b/lldb/source/Protocol/MCP/Server.cpp @@ -7,8 +7,16 @@ //===----------------------------------------------------------------------===// #include "lldb/Protocol/MCP/Server.h" +#include "lldb/Host/Socket.h" +#include "lldb/Protocol/MCP/Binder.h" #include "lldb/Protocol/MCP/MCPError.h" +#include "lldb/Protocol/MCP/Protocol.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/Threading.h" +#include <future> +#include <memory> +using namespace lldb_private; using namespace lldb_protocol::mcp; using namespace llvm; @@ -16,83 +24,13 @@ Server::Server(std::string name, std::string version, std::unique_ptr<MCPTransport> transport_up, lldb_private::MainLoop &loop) : m_name(std::move(name)), m_version(std::move(version)), - m_transport_up(std::move(transport_up)), m_loop(loop) { - AddRequestHandlers(); -} - -void Server::AddRequestHandlers() { - AddRequestHandler("initialize", std::bind(&Server::InitializeHandler, this, - std::placeholders::_1)); - AddRequestHandler("tools/list", std::bind(&Server::ToolsListHandler, this, - std::placeholders::_1)); - AddRequestHandler("tools/call", std::bind(&Server::ToolsCallHandler, this, - std::placeholders::_1)); - AddRequestHandler("resources/list", std::bind(&Server::ResourcesListHandler, - this, std::placeholders::_1)); - AddRequestHandler("resources/read", std::bind(&Server::ResourcesReadHandler, - this, std::placeholders::_1)); -} - -llvm::Expected<Response> Server::Handle(Request request) { - auto it = m_request_handlers.find(request.method); - if (it != m_request_handlers.end()) { - llvm::Expected<Response> response = it->second(request); - if (!response) - return response; - response->id = request.id; - return *response; - } - - return llvm::make_error<MCPError>( - llvm::formatv("no handler for request: {0}", request.method).str()); -} - -void Server::Handle(Notification notification) { - auto it = m_notification_handlers.find(notification.method); - if (it != m_notification_handlers.end()) { - it->second(notification); - return; - } -} - -llvm::Expected<std::optional<Message>> -Server::HandleData(llvm::StringRef data) { - auto message = llvm::json::parse<Message>(/*JSON=*/data); - if (!message) - return message.takeError(); - - if (const Request *request = std::get_if<Request>(&(*message))) { - llvm::Expected<Response> response = Handle(*request); - - // Handle failures by converting them into an Error message. - if (!response) { - Error protocol_error; - llvm::handleAllErrors( - response.takeError(), - [&](const MCPError &err) { protocol_error = err.toProtocolError(); }, - [&](const llvm::ErrorInfoBase &err) { - protocol_error.code = MCPError::kInternalError; - protocol_error.message = err.message(); - }); - Response error_response; - error_response.id = request->id; - error_response.result = std::move(protocol_error); - return error_response; - } - - return *response; - } - - if (const Notification *notification = - std::get_if<Notification>(&(*message))) { - Handle(*notification); - return std::nullopt; - } - - if (std::get_if<Response>(&(*message))) - return llvm::createStringError("unexpected MCP message: response"); - - llvm_unreachable("all message types handled"); + m_transport_up(std::move(transport_up)), m_loop(loop), + m_binder(m_transport_up.get()) { + m_binder.request("initialize", &Server::InitializeHandler, this); + m_binder.request("tools/list", &Server::ToolsListHandler, this); + m_binder.request("tools/call", &Server::ToolsCallHandler, this); + m_binder.request("resources/list", &Server::ResourcesListHandler, this); + m_binder.request("resources/read", &Server::ResourcesReadHandler, this); } void Server::AddTool(std::unique_ptr<Tool> tool) { @@ -112,54 +50,30 @@ void Server::AddResourceProvider( m_resource_providers.push_back(std::move(resource_provider)); } -void Server::AddRequestHandler(llvm::StringRef method, RequestHandler handler) { - std::lock_guard<std::mutex> guard(m_mutex); - m_request_handlers[method] = std::move(handler); -} - -void Server::AddNotificationHandler(llvm::StringRef method, - NotificationHandler handler) { - std::lock_guard<std::mutex> guard(m_mutex); - m_notification_handlers[method] = std::move(handler); -} - -llvm::Expected<Response> Server::InitializeHandler(const Request &request) { - Response response; - response.result = llvm::json::Object{ - {"protocolVersion", mcp::kProtocolVersion}, - {"capabilities", GetCapabilities()}, - {"serverInfo", - llvm::json::Object{{"name", m_name}, {"version", m_version}}}}; - return response; +Expected<InitializeResult> +Server::InitializeHandler(const InitializeParams &request) { + InitializeResult result; + result.protocolVersion = mcp::kProtocolVersion; + result.capabilities = GetCapabilities(); + result.serverInfo = Implementation{m_name, "", m_version}; + return result; } -llvm::Expected<Response> Server::ToolsListHandler(const Request &request) { - Response response; +llvm::Expected<ToolsListResult> Server::ToolsListHandler(const Void &) { + ToolsListResult result; - llvm::json::Array tools; for (const auto &tool : m_tools) - tools.emplace_back(toJSON(tool.second->GetDefinition())); + result.tools.emplace_back(tool.second->GetDefinition()); - response.result = llvm::json::Object{{"tools", std::move(tools)}}; - - return response; + return result; } -llvm::Expected<Response> Server::ToolsCallHandler(const Request &request) { - Response response; - - if (!request.params) - return llvm::createStringError("no tool parameters"); - - const json::Object *param_obj = request.params->getAsObject(); - if (!param_obj) - return llvm::createStringError("no tool parameters"); - - const json::Value *name = param_obj->get("name"); - if (!name) +llvm::Expected<ToolsCallResult> +Server::ToolsCallHandler(const ToolsCallParams ¶ms) { + if (params.name.empty()) return llvm::createStringError("no tool name"); - llvm::StringRef tool_name = name->getAsString().value_or(""); + llvm::StringRef tool_name = params.name; if (tool_name.empty()) return llvm::createStringError("no tool name"); @@ -168,56 +82,41 @@ llvm::Expected<Response> Server::ToolsCallHandler(const Request &request) { return llvm::createStringError(llvm::formatv("no tool \"{0}\"", tool_name)); ToolArguments tool_args; - if (const json::Value *args = param_obj->get("arguments")) - tool_args = *args; + if (params.arguments) + tool_args = *params.arguments; - llvm::Expected<TextResult> text_result = it->second->Call(tool_args); - if (!text_result) - return text_result.takeError(); - - response.result = toJSON(*text_result); - - return response; + std::promise<llvm::Expected<ToolsCallResult>> result_promise; + it->second->Call(tool_args, + [&result_promise](llvm::Expected<ToolsCallResult> result) { + result_promise.set_value(std::move(result)); + }); + return result_promise.get_future().get(); } -llvm::Expected<Response> Server::ResourcesListHandler(const Request &request) { - Response response; - - llvm::json::Array resources; +llvm::Expected<ResourcesListResult> Server::ResourcesListHandler(const Void &) { + ResourcesListResult result; std::lock_guard<std::mutex> guard(m_mutex); for (std::unique_ptr<ResourceProvider> &resource_provider_up : - m_resource_providers) { + m_resource_providers) for (const Resource &resource : resource_provider_up->GetResources()) - resources.push_back(resource); - } - response.result = llvm::json::Object{{"resources", std::move(resources)}}; + result.resources.push_back(resource); - return response; + return result; } -llvm::Expected<Response> Server::ResourcesReadHandler(const Request &request) { - Response response; - - if (!request.params) - return llvm::createStringError("no resource parameters"); - - const json::Object *param_obj = request.params->getAsObject(); - if (!param_obj) - return llvm::createStringError("no resource parameters"); - - const json::Value *uri = param_obj->get("uri"); - if (!uri) - return llvm::createStringError("no resource uri"); +llvm::Expected<ResourcesReadResult> +Server::ResourcesReadHandler(const ResourcesReadParams ¶ms) { + ResourcesReadResult result; - llvm::StringRef uri_str = uri->getAsString().value_or(""); + llvm::StringRef uri_str = params.URI; if (uri_str.empty()) return llvm::createStringError("no resource uri"); std::lock_guard<std::mutex> guard(m_mutex); for (std::unique_ptr<ResourceProvider> &resource_provider_up : m_resource_providers) { - llvm::Expected<ResourceResult> result = + llvm::Expected<ResourcesReadResult> result = resource_provider_up->ReadResource(uri_str); if (result.errorIsA<UnsupportedURI>()) { llvm::consumeError(result.takeError()); @@ -225,10 +124,7 @@ llvm::Expected<Response> Server::ResourcesReadHandler(const Request &request) { } if (!result) return result.takeError(); - - Response response; - response.result = std::move(*result); - return response; + return *result; } return make_error<MCPError>( @@ -236,17 +132,18 @@ llvm::Expected<Response> Server::ResourcesReadHandler(const Request &request) { MCPError::kResourceNotFound); } -Capabilities Server::GetCapabilities() { - lldb_protocol::mcp::Capabilities capabilities; - capabilities.tools.listChanged = true; +ServerCapabilities Server::GetCapabilities() { + ServerCapabilities capabilities; + capabilities.supportsToolsList = true; + capabilities.supportsResourcesList = true; // FIXME: Support sending notifications when a debugger/target are // added/removed. - capabilities.resources.listChanged = false; + // capabilities.supportsResourcesSubscribe = true; return capabilities; } llvm::Error Server::Run() { - auto handle = m_transport_up->RegisterMessageHandler(m_loop, *this); + auto handle = m_transport_up->RegisterMessageHandler(m_loop, m_binder); if (!handle) return handle.takeError(); @@ -257,48 +154,6 @@ llvm::Error Server::Run() { return llvm::Error::success(); } -void Server::Received(const Request &request) { - auto SendResponse = [this](const Response &response) { - if (llvm::Error error = m_transport_up->Send(response)) - m_transport_up->Log(llvm::toString(std::move(error))); - }; - - llvm::Expected<Response> response = Handle(request); - if (response) - return SendResponse(*response); - - lldb_protocol::mcp::Error protocol_error; - llvm::handleAllErrors( - response.takeError(), - [&](const MCPError &err) { protocol_error = err.toProtocolError(); }, - [&](const llvm::ErrorInfoBase &err) { - protocol_error.code = MCPError::kInternalError; - protocol_error.message = err.message(); - }); - Response error_response; - error_response.id = request.id; - error_response.result = std::move(protocol_error); - SendResponse(error_response); -} - -void Server::Received(const Response &response) { - m_transport_up->Log("unexpected MCP message: response"); -} - -void Server::Received(const Notification ¬ification) { - Handle(notification); -} - -void Server::OnError(llvm::Error error) { - m_transport_up->Log(llvm::toString(std::move(error))); - TerminateLoop(); -} - -void Server::OnClosed() { - m_transport_up->Log("EOF"); - TerminateLoop(); -} - void Server::TerminateLoop() { m_loop.AddPendingCallback( [](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); }); diff --git a/lldb/source/Protocol/MCP/Transport.cpp b/lldb/source/Protocol/MCP/Transport.cpp new file mode 100644 index 0000000000000..28cf754aef3e8 --- /dev/null +++ b/lldb/source/Protocol/MCP/Transport.cpp @@ -0,0 +1,113 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "lldb/Protocol/MCP/Transport.h" +#include "lldb/Host/FileSystem.h" +#include "lldb/Host/HostInfo.h" +#include "lldb/Host/Socket.h" +#include "lldb/Utility/FileSpec.h" +#include "lldb/lldb-forward.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/Path.h" +#include "llvm/Support/Program.h" +#include "llvm/Support/Threading.h" +#include "llvm/Support/raw_ostream.h" +#include <memory> +#include <thread> + +using namespace llvm; +using namespace lldb; +using namespace lldb_private; + +namespace lldb_protocol::mcp { + +static Expected<sys::ProcessInfo> StartServer() { + static once_flag f; + static FileSpec candidate; + llvm::call_once(f, [] { + HostInfo::Initialize(); + candidate = HostInfo::GetSupportExeDir(); + candidate.AppendPathComponent("lldb-mcp"); + }); + + if (!FileSystem::Instance().Exists(candidate)) + return createStringError("lldb-mcp executable not found"); + std::vector<StringRef> args = {candidate.GetPath(), "--server"}; + sys::ProcessInfo proc = + sys::ExecuteNoWait(candidate.GetPath(), args, std::nullopt, {}, 0, + nullptr, nullptr, nullptr, /*DetachProcess=*/true); + if (proc.Pid == sys::ProcessInfo::InvalidPid) + return createStringError("Failed to start server: " + candidate.GetPath()); + StringRef socket_path = CommunicationSocketPath(); + while (!sys::fs::exists(socket_path)) + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + return proc; +} + +Transport::Transport(lldb::IOObjectSP input, lldb::IOObjectSP output, + std::string client_name, LogCallback log_callback) + : JSONRPCTransport(input, output), m_client_name(client_name), + m_log_callback(log_callback) {} + +void Transport::Log(llvm::StringRef message) { + if (m_log_callback) + m_log_callback(llvm::formatv("{0}: {1}", m_client_name, message).str()); +} + +llvm::StringRef CommunicationSocketPath() { + static std::once_flag f; + static SmallString<256> socket_path; + llvm::call_once(f, [] { + assert(sys::path::home_directory(socket_path) && + "failed to get home directory"); + sys::path::append(socket_path, ".lldb-mcp-sock"); + }); + return socket_path.str(); +} + +Expected<IOObjectSP> Connect() { + StringRef socket_path = CommunicationSocketPath(); + if (!sys::fs::exists(socket_path)) + if (llvm::Error err = StartServer().takeError()) + return err; + + Socket::SocketProtocol protocol = Socket::ProtocolUnixDomain; + Status error; + std::unique_ptr<Socket> socket = Socket::Create(protocol, error); + if (error.Fail()) + return error.takeError(); + std::chrono::steady_clock::time_point deadline = + std::chrono::steady_clock::now() + std::chrono::seconds(30); + while (std::chrono::steady_clock::now() < deadline) { + Status error = socket->Connect(socket_path); + if (error.Success()) { + return socket; + } + if (error.Fail() && error.GetError() != ECONNREFUSED && + error.GetError() != ENOENT) + return error.takeError(); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + + return createStringError("failed to connect to lldb-mcp multiplexer"); +} + +Expected<MCPTransportUP> Transport::Connect(llvm::raw_ostream *logger) { + Expected<IOObjectSP> maybe_sock = lldb_protocol::mcp::Connect(); + if (!maybe_sock) + return maybe_sock.takeError(); + + return std::make_unique<Transport>(*maybe_sock, *maybe_sock, "client", + [logger](StringRef msg) { + if (logger) + *logger << msg << "\n"; + }); +} + +} // namespace lldb_protocol::mcp diff --git a/lldb/unittests/Protocol/ProtocolMCPTest.cpp b/lldb/unittests/Protocol/ProtocolMCPTest.cpp index ea19922522ffe..45024b5ca9f3d 100644 --- a/lldb/unittests/Protocol/ProtocolMCPTest.cpp +++ b/lldb/unittests/Protocol/ProtocolMCPTest.cpp @@ -277,10 +277,11 @@ TEST(ProtocolMCPTest, ResourceResult) { contents2.text = "Second resource content"; contents2.mimeType = "application/json"; - ResourceResult result; + ResourcesReadResult result; result.contents = {contents1, contents2}; - llvm::Expected<ResourceResult> deserialized_result = roundtripJSON(result); + llvm::Expected<ResourcesReadResult> deserialized_result = + roundtripJSON(result); ASSERT_THAT_EXPECTED(deserialized_result, llvm::Succeeded()); ASSERT_EQ(result.contents.size(), deserialized_result->contents.size()); @@ -297,9 +298,10 @@ TEST(ProtocolMCPTest, ResourceResult) { } TEST(ProtocolMCPTest, ResourceResultEmpty) { - ResourceResult result; + ResourcesReadResult result; - llvm::Expected<ResourceResult> deserialized_result = roundtripJSON(result); + llvm::Expected<ResourcesReadResult> deserialized_result = + roundtripJSON(result); ASSERT_THAT_EXPECTED(deserialized_result, llvm::Succeeded()); EXPECT_TRUE(deserialized_result->contents.empty()); diff --git a/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp b/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp index 83a42bfb6970c..91c47c2229320 100644 --- a/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp +++ b/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp @@ -19,6 +19,7 @@ #include "lldb/Host/MainLoopBase.h" #include "lldb/Host/Socket.h" #include "lldb/Host/common/TCPSocket.h" +#include "lldb/Protocol/MCP/Binder.h" #include "lldb/Protocol/MCP/MCPError.h" #include "lldb/Protocol/MCP/Protocol.h" #include "llvm/Support/Error.h" @@ -36,18 +37,34 @@ using namespace lldb_private; using namespace lldb_protocol::mcp; using testing::_; +namespace lldb_protocol::mcp { +void PrintTo(const Request &req, std::ostream *os) { + *os << formatv("{0}", toJSON(req)).str(); +} +void PrintTo(const Response &resp, std::ostream *os) { + *os << formatv("{0}", toJSON(resp)).str(); +} +void PrintTo(const Notification ¬e, std::ostream *os) { + *os << formatv("{0}", toJSON(note)).str(); +} +void PrintTo(const Message &message, std::ostream *os) { + return std::visit([os](auto &&message) { return PrintTo(message, os); }, + message); +} +} // namespace lldb_protocol::mcp + namespace { class TestProtocolServerMCP : public lldb_private::mcp::ProtocolServerMCP { public: using ProtocolServerMCP::GetSocket; using ProtocolServerMCP::ProtocolServerMCP; - using ExtendCallback = - std::function<void(lldb_protocol::mcp::Server &server)>; + using ExtendCallback = std::function<void( + lldb_protocol::mcp::Server &server, lldb_protocol::mcp::Binder &binder)>; - virtual void Extend(lldb_protocol::mcp::Server &server) const override { + void Extend(lldb_protocol::mcp::Server &server) override { if (m_extend_callback) - m_extend_callback(server); + m_extend_callback(server, server.GetBinder()); }; void Extend(ExtendCallback callback) { m_extend_callback = callback; } @@ -55,7 +72,7 @@ class TestProtocolServerMCP : public lldb_private::mcp::ProtocolServerMCP { ExtendCallback m_extend_callback; }; -using Message = typename Transport<Request, Response, Notification>::Message; +using Message = typename lldb_protocol::mcp::Transport::Message; class TestJSONTransport final : public lldb_private::JSONRPCTransport<Request, Response, Notification> { @@ -74,7 +91,8 @@ class TestTool : public Tool { public: using Tool::Tool; - llvm::Expected<TextResult> Call(const ToolArguments &args) override { + void Call(const ToolArguments &args, + Callback<void(llvm::Expected<ToolsCallResult>)> reply) override { std::string argument; if (const json::Object *args_obj = std::get<json::Value>(args).getAsObject()) { @@ -83,9 +101,9 @@ class TestTool : public Tool { } } - TextResult text_result; + ToolsCallResult text_result; text_result.content.emplace_back(TextContent{{argument}}); - return text_result; + reply(text_result); } }; @@ -105,7 +123,7 @@ class TestResourceProvider : public ResourceProvider { return resources; } - llvm::Expected<ResourceResult> + llvm::Expected<ResourcesReadResult> ReadResource(llvm::StringRef uri) const override { if (uri != "lldb://foo/bar") return llvm::make_error<UnsupportedURI>(uri.str()); @@ -115,7 +133,7 @@ class TestResourceProvider : public ResourceProvider { contents.mimeType = "application/json"; contents.text = "foobar"; - ResourceResult result; + ResourcesReadResult result; result.contents.push_back(contents); return result; } @@ -126,8 +144,9 @@ class ErrorTool : public Tool { public: using Tool::Tool; - llvm::Expected<TextResult> Call(const ToolArguments &args) override { - return llvm::createStringError("error"); + void Call(const ToolArguments &args, + Callback<void(llvm::Expected<ToolsCallResult>)> reply) override { + reply(llvm::createStringError("error")); } }; @@ -136,11 +155,12 @@ class FailTool : public Tool { public: using Tool::Tool; - llvm::Expected<TextResult> Call(const ToolArguments &args) override { - TextResult text_result; + void Call(const ToolArguments &args, + Callback<void(llvm::Expected<ToolsCallResult>)> reply) override { + ToolsCallResult text_result; text_result.content.emplace_back(TextContent{{"failed"}}); text_result.isError = true; - return text_result; + reply(text_result); } }; @@ -191,7 +211,7 @@ class ProtocolServerMCPTest : public ::testing::Test { connection.protocol = Socket::SocketProtocol::ProtocolTcp; connection.name = llvm::formatv("{0}:0", k_localhost).str(); m_server_up = std::make_unique<TestProtocolServerMCP>(); - m_server_up->Extend([&](auto &server) { + m_server_up->Extend([&](auto &server, Binder &binder) { server.AddTool(std::make_unique<TestTool>("test", "test tool")); server.AddResourceProvider(std::make_unique<TestResourceProvider>()); }); @@ -225,7 +245,7 @@ TEST_F(ProtocolServerMCPTest, Initialization) { llvm::StringLiteral request = R"json({"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"lldb-unit","version":"0.1.0"}},"jsonrpc":"2.0","id":1})json"; llvm::StringLiteral response = - R"json({"id":1,"jsonrpc":"2.0","result":{"capabilities":{"resources":{"listChanged":false,"subscribe":false},"tools":{"listChanged":true}},"protocolVersion":"2024-11-05","serverInfo":{"name":"lldb-mcp","version":"0.1.0"}}})json"; + R"json({"id":1,"jsonrpc":"2.0","result":{"capabilities":{"resources":{"listChanged":true},"tools":{"listChanged":true}},"protocolVersion":"2024-11-05","serverInfo":{"name":"lldb-mcp","version":"0.1.0"}}})json"; ASSERT_THAT_ERROR(Write(request), Succeeded()); llvm::Expected<Response> expected_resp = json::parse<Response>(response); @@ -271,7 +291,7 @@ TEST_F(ProtocolServerMCPTest, ToolsCall) { llvm::StringLiteral request = R"json({"method":"tools/call","params":{"name":"test","arguments":{"arguments":"foo","debugger_id":0}},"jsonrpc":"2.0","id":11})json"; llvm::StringLiteral response = - R"json({"id":11,"jsonrpc":"2.0","result":{"content":[{"text":"foo","type":"text"}],"isError":false}})json"; + R"json({"id":11,"jsonrpc":"2.0","result":{"content":[{"text":"foo","type":"text"}]}})json"; ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); llvm::Expected<Response> expected_resp = json::parse<Response>(response); @@ -281,7 +301,7 @@ TEST_F(ProtocolServerMCPTest, ToolsCall) { } TEST_F(ProtocolServerMCPTest, ToolsCallError) { - m_server_up->Extend([&](auto &server) { + m_server_up->Extend([&](auto &server, auto &binder) { server.AddTool(std::make_unique<ErrorTool>("error", "error tool")); }); @@ -298,7 +318,7 @@ TEST_F(ProtocolServerMCPTest, ToolsCallError) { } TEST_F(ProtocolServerMCPTest, ToolsCallFail) { - m_server_up->Extend([&](auto &server) { + m_server_up->Extend([&](auto &server, auto &binder) { server.AddTool(std::make_unique<FailTool>("fail", "fail tool")); }); @@ -319,15 +339,15 @@ TEST_F(ProtocolServerMCPTest, NotificationInitialized) { std::condition_variable cv; std::mutex mutex; - m_server_up->Extend([&](auto &server) { - server.AddNotificationHandler("notifications/initialized", - [&](const Notification ¬ification) { - { - std::lock_guard<std::mutex> lock(mutex); - handler_called = true; - } - cv.notify_all(); - }); + m_server_up->Extend([&](auto &server, auto &binder) { + binder.template notification<Void>( + "notifications/initialized", [&](const Void &) { + { + std::lock_guard<std::mutex> lock(mutex); + handler_called = true; + } + cv.notify_all(); + }); }); llvm::StringLiteral request = R"json({"method":"notifications/initialized","jsonrpc":"2.0"})json"; _______________________________________________ lldb-commits mailing list lldb-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/lldb-commits