Author: John Harrison Date: 2025-08-26T15:09:36-07:00 New Revision: a67257bbfb5bce5a21f21d7e78049cfcbb283e33
URL: https://github.com/llvm/llvm-project/commit/a67257bbfb5bce5a21f21d7e78049cfcbb283e33 DIFF: https://github.com/llvm/llvm-project/commit/a67257bbfb5bce5a21f21d7e78049cfcbb283e33.diff LOG: [lldb] Adding structured types for existing MCP calls. (#155460) This adds or renames existing types to match the names of the types on https://modelcontextprotocol.io/specification/2025-06-18/schema for the existing calls. The new types are used in the unit tests and server implementation to remove the need for crafting various `llvm::json::Object` values by hand. Added: lldb/unittests/Protocol/ProtocolMCPTestUtilities.h Modified: lldb/include/lldb/Protocol/MCP/Protocol.h lldb/include/lldb/Protocol/MCP/Resource.h lldb/include/lldb/Protocol/MCP/Server.h lldb/include/lldb/Protocol/MCP/Tool.h lldb/source/Plugins/Protocol/MCP/Resource.cpp lldb/source/Plugins/Protocol/MCP/Resource.h lldb/source/Plugins/Protocol/MCP/Tool.cpp lldb/source/Plugins/Protocol/MCP/Tool.h lldb/source/Protocol/MCP/Protocol.cpp lldb/source/Protocol/MCP/Server.cpp lldb/unittests/Protocol/ProtocolMCPServerTest.cpp lldb/unittests/Protocol/ProtocolMCPTest.cpp Removed: ################################################################################ diff --git a/lldb/include/lldb/Protocol/MCP/Protocol.h b/lldb/include/lldb/Protocol/MCP/Protocol.h index 49f94902217558..6e1ffcbe1f3e33 100644 --- a/lldb/include/lldb/Protocol/MCP/Protocol.h +++ b/lldb/include/lldb/Protocol/MCP/Protocol.h @@ -18,6 +18,7 @@ #include <optional> #include <string> #include <variant> +#include <vector> namespace lldb_protocol::mcp { @@ -38,11 +39,24 @@ struct Request { /// The method's params. std::optional<llvm::json::Value> params; }; - llvm::json::Value toJSON(const Request &); bool fromJSON(const llvm::json::Value &, Request &, llvm::json::Path); bool operator==(const Request &, const Request &); +enum ErrorCode : signed { + /// Invalid JSON was received by the server. An error occurred on the server + /// while parsing the JSON text. 
+ eErrorCodeParseError = -32700, + /// The JSON sent is not a valid Request object. + eErrorCodeInvalidRequest = -32600, + /// The method does not exist / is not available. + eErrorCodeMethodNotFound = -32601, + /// Invalid method parameter(s). + eErrorCodeInvalidParams = -32602, + /// Internal JSON-RPC error. + eErrorCodeInternalError = -32603, +}; + struct Error { /// The error type that occurred. int64_t code = 0; @@ -52,9 +66,8 @@ struct Error { /// Additional information about the error. The value of this member is /// defined by the sender (e.g. detailed error information, nested errors /// etc.). - std::optional<llvm::json::Value> data; + std::optional<llvm::json::Value> data = std::nullopt; }; - llvm::json::Value toJSON(const Error &); bool fromJSON(const llvm::json::Value &, Error &, llvm::json::Path); bool operator==(const Error &, const Error &); @@ -67,7 +80,6 @@ struct Response { /// response. std::variant<Error, llvm::json::Value> result; }; - llvm::json::Value toJSON(const Response &); bool fromJSON(const llvm::json::Value &, Response &, llvm::json::Path); bool operator==(const Response &, const Response &); @@ -79,7 +91,6 @@ struct Notification { /// The notification's params. std::optional<llvm::json::Value> params; }; - llvm::json::Value toJSON(const Notification &); bool fromJSON(const llvm::json::Value &, Notification &, llvm::json::Path); bool operator==(const Notification &, const Notification &); @@ -90,45 +101,9 @@ using Message = std::variant<Request, Response, Notification>; // not force it to be checked early here. static_assert(std::is_convertible_v<Message, Message>, "Message is not convertible to itself"); - bool fromJSON(const llvm::json::Value &, Message &, llvm::json::Path); llvm::json::Value toJSON(const Message &); -struct ToolCapability { - /// Whether this server supports notifications for changes to the tool list. 
- bool listChanged = false; -}; - -llvm::json::Value toJSON(const ToolCapability &); -bool fromJSON(const llvm::json::Value &, ToolCapability &, llvm::json::Path); - -struct ResourceCapability { - /// Whether this server supports notifications for changes to the resources - /// list. - bool listChanged = false; - - /// Whether subscriptions are supported. - bool subscribe = false; -}; - -llvm::json::Value toJSON(const ResourceCapability &); -bool fromJSON(const llvm::json::Value &, ResourceCapability &, - llvm::json::Path); - -/// Capabilities that a server may support. Known capabilities are defined here, -/// in this schema, but this is not a closed set: any server can define its own, -/// additional capabilities. -struct Capabilities { - /// Tool capabilities of the server. - ToolCapability tools; - - /// Resource capabilities of the server. - ResourceCapability resources; -}; - -llvm::json::Value toJSON(const Capabilities &); -bool fromJSON(const llvm::json::Value &, Capabilities &, llvm::json::Path); - /// A known resource that the server is capable of reading. struct Resource { /// The URI of this resource. @@ -138,17 +113,25 @@ struct Resource { std::string name; /// A description of what this resource represents. - std::string description; + std::string description = ""; /// The MIME type of this resource, if known. - std::string mimeType; + std::string mimeType = ""; }; llvm::json::Value toJSON(const Resource &); bool fromJSON(const llvm::json::Value &, Resource &, llvm::json::Path); +/// The server’s response to a resources/list request from the client. +struct ListResourcesResult { + std::vector<Resource> resources; +}; +llvm::json::Value toJSON(const ListResourcesResult &); +bool fromJSON(const llvm::json::Value &, ListResourcesResult &, + llvm::json::Path); + /// The contents of a specific resource or sub-resource. -struct ResourceContents { +struct TextResourceContents { /// The URI of this resource. 
std::string uri; @@ -160,34 +143,37 @@ struct ResourceContents { std::string mimeType; }; -llvm::json::Value toJSON(const ResourceContents &); -bool fromJSON(const llvm::json::Value &, ResourceContents &, llvm::json::Path); +llvm::json::Value toJSON(const TextResourceContents &); +bool fromJSON(const llvm::json::Value &, TextResourceContents &, + llvm::json::Path); -/// The server's response to a resources/read request from the client. -struct ResourceResult { - std::vector<ResourceContents> contents; +/// Sent from the client to the server, to read a specific resource URI. +struct ReadResourceParams { + /// The URI of the resource to read. The URI can use any protocol; it is up to + /// the server how to interpret it. + std::string uri; }; +llvm::json::Value toJSON(const ReadResourceParams &); +bool fromJSON(const llvm::json::Value &, ReadResourceParams &, + llvm::json::Path); -llvm::json::Value toJSON(const ResourceResult &); -bool fromJSON(const llvm::json::Value &, ResourceResult &, llvm::json::Path); +/// The server's response to a resources/read request from the client. +struct ReadResourceResult { + std::vector<TextResourceContents> contents; +}; +llvm::json::Value toJSON(const ReadResourceResult &); +bool fromJSON(const llvm::json::Value &, ReadResourceResult &, + llvm::json::Path); /// Text provided to or from an LLM. struct TextContent { /// The text content of the message. std::string text; }; - llvm::json::Value toJSON(const TextContent &); bool fromJSON(const llvm::json::Value &, TextContent &, llvm::json::Path); -struct TextResult { - std::vector<TextContent> content; - bool isError = false; -}; - -llvm::json::Value toJSON(const TextResult &); -bool fromJSON(const llvm::json::Value &, TextResult &, llvm::json::Path); - +/// Definition for a tool the client can call. struct ToolDefinition { /// Unique identifier for the tool. std::string name; @@ -198,12 +184,144 @@ struct ToolDefinition { // JSON Schema for the tool's parameters. 
std::optional<llvm::json::Value> inputSchema; }; - llvm::json::Value toJSON(const ToolDefinition &); bool fromJSON(const llvm::json::Value &, ToolDefinition &, llvm::json::Path); using ToolArguments = std::variant<std::monostate, llvm::json::Value>; +/// Describes the name and version of an MCP implementation, with an optional +/// title for UI representation. +struct Implementation { + /// Intended for programmatic or logical use, but used as a display name in + /// past specs or fallback (if title isn’t present). + std::string name; + + std::string version; + + /// Intended for UI and end-user contexts — optimized to be human-readable and + /// easily understood, even by those unfamiliar with domain-specific + /// terminology. + /// + /// If not provided, the name should be used for display (except for Tool, + /// where annotations.title should be given precedence over using name, if + /// present). + std::string title = ""; +}; +llvm::json::Value toJSON(const Implementation &); +bool fromJSON(const llvm::json::Value &, Implementation &, llvm::json::Path); + +/// Capabilities a client may support. Known capabilities are defined here, in +/// this schema, but this is not a closed set: any client can define its own, +/// additional capabilities. +struct ClientCapabilities {}; +llvm::json::Value toJSON(const ClientCapabilities &); +bool fromJSON(const llvm::json::Value &, ClientCapabilities &, + llvm::json::Path); + +/// Capabilities that a server may support. Known capabilities are defined here, +/// in this schema, but this is not a closed set: any server can define its own, +/// additional capabilities. +struct ServerCapabilities { + bool supportsToolsList = false; + bool supportsResourcesList = false; + bool supportsResourcesSubscribe = false; + + /// Utilities. 
+ bool supportsCompletions = false; + bool supportsLogging = false; +}; +llvm::json::Value toJSON(const ServerCapabilities &); +bool fromJSON(const llvm::json::Value &, ServerCapabilities &, + llvm::json::Path); + +/// Initialization + +/// This request is sent from the client to the server when it first connects, +/// asking it to begin initialization. +struct InitializeParams { + /// The latest version of the Model Context Protocol that the client supports. + /// The client MAY decide to support older versions as well. + std::string protocolVersion; + + ClientCapabilities capabilities; + + Implementation clientInfo; +}; +llvm::json::Value toJSON(const InitializeParams &); +bool fromJSON(const llvm::json::Value &, InitializeParams &, llvm::json::Path); + +/// After receiving an initialize request from the client, the server sends this +/// response. +struct InitializeResult { + /// The version of the Model Context Protocol that the server wants to use. + /// This may not match the version that the client requested. If the client + /// cannot support this version, it MUST disconnect. + std::string protocolVersion; + + ServerCapabilities capabilities; + Implementation serverInfo; + + /// Instructions describing how to use the server and its features. + /// + /// This can be used by clients to improve the LLM's understanding of + /// available tools, resources, etc. It can be thought of like a "hint" to the + /// model. For example, this information MAY be added to the system prompt. + std::string instructions = ""; +}; +llvm::json::Value toJSON(const InitializeResult &); +bool fromJSON(const llvm::json::Value &, InitializeResult &, llvm::json::Path); + +/// Special case parameter or result that has no value. +using Void = std::monostate; +llvm::json::Value toJSON(const Void &); +bool fromJSON(const llvm::json::Value &, Void &, llvm::json::Path); + +/// The server's response to a `tools/list` request from the client. 
+struct ListToolsResult { + std::vector<ToolDefinition> tools; +}; +llvm::json::Value toJSON(const ListToolsResult &); +bool fromJSON(const llvm::json::Value &, ListToolsResult &, llvm::json::Path); + +/// Supported content types, currently only TextContent, but the spec includes +/// additional content types. +using ContentBlock = TextContent; + +/// Used by the client to invoke a tool provided by the server. +struct CallToolParams { + std::string name; + std::optional<llvm::json::Value> arguments; +}; +llvm::json::Value toJSON(const CallToolParams &); +bool fromJSON(const llvm::json::Value &, CallToolParams &, llvm::json::Path); + +/// The server’s response to a tool call. +struct CallToolResult { + /// A list of content objects that represent the unstructured result of the + /// tool call. + std::vector<ContentBlock> content; + + /// Whether the tool call ended in an error. + /// + /// If not set, this is assumed to be false (the call was successful). + /// + /// Any errors that originate from the tool SHOULD be reported inside the + /// result object, with `isError` set to true, not as an MCP protocol-level + /// error response. Otherwise, the LLM would not be able to see that an error + /// occurred and self-correct. + /// + /// However, any errors in finding the tool, an error indicating that the + /// server does not support tool calls, or any other exceptional conditions, + /// should be reported as an MCP error response. + bool isError = false; + + /// An optional JSON object that represents the structured result of the tool + /// call. 
+ std::optional<llvm::json::Value> structuredContent = std::nullopt; +}; +llvm::json::Value toJSON(const CallToolResult &); +bool fromJSON(const llvm::json::Value &, CallToolResult &, llvm::json::Path); + } // namespace lldb_protocol::mcp #endif diff --git a/lldb/include/lldb/Protocol/MCP/Resource.h b/lldb/include/lldb/Protocol/MCP/Resource.h index 4835d340cd4c6a..158cffc71ea101 100644 --- a/lldb/include/lldb/Protocol/MCP/Resource.h +++ b/lldb/include/lldb/Protocol/MCP/Resource.h @@ -20,7 +20,7 @@ class ResourceProvider { virtual ~ResourceProvider() = default; virtual std::vector<lldb_protocol::mcp::Resource> GetResources() const = 0; - virtual llvm::Expected<lldb_protocol::mcp::ResourceResult> + virtual llvm::Expected<lldb_protocol::mcp::ReadResourceResult> ReadResource(llvm::StringRef uri) const = 0; }; diff --git a/lldb/include/lldb/Protocol/MCP/Server.h b/lldb/include/lldb/Protocol/MCP/Server.h index 2b9e919329752c..aa5714e45755e5 100644 --- a/lldb/include/lldb/Protocol/MCP/Server.h +++ b/lldb/include/lldb/Protocol/MCP/Server.h @@ -58,7 +58,7 @@ class Server : public MCPTransport::MessageHandler { llvm::Error Run(); protected: - Capabilities GetCapabilities(); + ServerCapabilities GetCapabilities(); using RequestHandler = std::function<llvm::Expected<Response>(const Request &)>; diff --git a/lldb/include/lldb/Protocol/MCP/Tool.h b/lldb/include/lldb/Protocol/MCP/Tool.h index 96669d1357166a..6c9f05161f8e7b 100644 --- a/lldb/include/lldb/Protocol/MCP/Tool.h +++ b/lldb/include/lldb/Protocol/MCP/Tool.h @@ -10,6 +10,7 @@ #define LLDB_PROTOCOL_MCP_TOOL_H #include "lldb/Protocol/MCP/Protocol.h" +#include "llvm/Support/Error.h" #include "llvm/Support/JSON.h" #include <string> @@ -20,7 +21,7 @@ class Tool { Tool(std::string name, std::string description); virtual ~Tool() = default; - virtual llvm::Expected<lldb_protocol::mcp::TextResult> + virtual llvm::Expected<lldb_protocol::mcp::CallToolResult> Call(const lldb_protocol::mcp::ToolArguments &args) = 0; virtual 
std::optional<llvm::json::Value> GetSchema() const { diff --git a/lldb/source/Plugins/Protocol/MCP/Resource.cpp b/lldb/source/Plugins/Protocol/MCP/Resource.cpp index e94d2cdd65e07a..581424510d4cf4 100644 --- a/lldb/source/Plugins/Protocol/MCP/Resource.cpp +++ b/lldb/source/Plugins/Protocol/MCP/Resource.cpp @@ -8,7 +8,6 @@ #include "lldb/Core/Debugger.h" #include "lldb/Core/Module.h" #include "lldb/Protocol/MCP/MCPError.h" -#include "lldb/Target/Platform.h" using namespace lldb_private; using namespace lldb_private::mcp; @@ -124,7 +123,7 @@ DebuggerResourceProvider::GetResources() const { return resources; } -llvm::Expected<lldb_protocol::mcp::ResourceResult> +llvm::Expected<lldb_protocol::mcp::ReadResourceResult> DebuggerResourceProvider::ReadResource(llvm::StringRef uri) const { auto [protocol, path] = uri.split("://"); @@ -161,7 +160,7 @@ DebuggerResourceProvider::ReadResource(llvm::StringRef uri) const { return ReadDebuggerResource(uri, debugger_idx); } -llvm::Expected<lldb_protocol::mcp::ResourceResult> +llvm::Expected<lldb_protocol::mcp::ReadResourceResult> DebuggerResourceProvider::ReadDebuggerResource(llvm::StringRef uri, lldb::user_id_t debugger_id) { lldb::DebuggerSP debugger_sp = Debugger::FindDebuggerWithID(debugger_id); @@ -173,17 +172,17 @@ DebuggerResourceProvider::ReadDebuggerResource(llvm::StringRef uri, debugger_resource.name = debugger_sp->GetInstanceName(); debugger_resource.num_targets = debugger_sp->GetTargetList().GetNumTargets(); - lldb_protocol::mcp::ResourceContents contents; + lldb_protocol::mcp::TextResourceContents contents; contents.uri = uri; contents.mimeType = kMimeTypeJSON; contents.text = llvm::formatv("{0}", toJSON(debugger_resource)); - lldb_protocol::mcp::ResourceResult result; + lldb_protocol::mcp::ReadResourceResult result; result.contents.push_back(contents); return result; } -llvm::Expected<lldb_protocol::mcp::ResourceResult> +llvm::Expected<lldb_protocol::mcp::ReadResourceResult> 
DebuggerResourceProvider::ReadTargetResource(llvm::StringRef uri, lldb::user_id_t debugger_id, size_t target_idx) { @@ -209,12 +208,12 @@ DebuggerResourceProvider::ReadTargetResource(llvm::StringRef uri, if (lldb::PlatformSP platform_sp = target_sp->GetPlatform()) target_resource.platform = platform_sp->GetName(); - lldb_protocol::mcp::ResourceContents contents; + lldb_protocol::mcp::TextResourceContents contents; contents.uri = uri; contents.mimeType = kMimeTypeJSON; contents.text = llvm::formatv("{0}", toJSON(target_resource)); - lldb_protocol::mcp::ResourceResult result; + lldb_protocol::mcp::ReadResourceResult result; result.contents.push_back(contents); return result; } diff --git a/lldb/source/Plugins/Protocol/MCP/Resource.h b/lldb/source/Plugins/Protocol/MCP/Resource.h index e2382a74f796b2..0c6576602905e5 100644 --- a/lldb/source/Plugins/Protocol/MCP/Resource.h +++ b/lldb/source/Plugins/Protocol/MCP/Resource.h @@ -11,7 +11,11 @@ #include "lldb/Protocol/MCP/Protocol.h" #include "lldb/Protocol/MCP/Resource.h" -#include "lldb/lldb-private.h" +#include "lldb/lldb-forward.h" +#include "lldb/lldb-types.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Error.h" +#include <cstddef> #include <vector> namespace lldb_private::mcp { @@ -21,9 +25,8 @@ class DebuggerResourceProvider : public lldb_protocol::mcp::ResourceProvider { using ResourceProvider::ResourceProvider; virtual ~DebuggerResourceProvider() = default; - virtual std::vector<lldb_protocol::mcp::Resource> - GetResources() const override; - virtual llvm::Expected<lldb_protocol::mcp::ResourceResult> + std::vector<lldb_protocol::mcp::Resource> GetResources() const override; + llvm::Expected<lldb_protocol::mcp::ReadResourceResult> ReadResource(llvm::StringRef uri) const override; private: @@ -31,9 +34,9 @@ class DebuggerResourceProvider : public lldb_protocol::mcp::ResourceProvider { static lldb_protocol::mcp::Resource GetTargetResource(size_t target_idx, Target &target); - static 
llvm::Expected<lldb_protocol::mcp::ResourceResult> + static llvm::Expected<lldb_protocol::mcp::ReadResourceResult> ReadDebuggerResource(llvm::StringRef uri, lldb::user_id_t debugger_id); - static llvm::Expected<lldb_protocol::mcp::ResourceResult> + static llvm::Expected<lldb_protocol::mcp::ReadResourceResult> ReadTargetResource(llvm::StringRef uri, lldb::user_id_t debugger_id, size_t target_idx); }; diff --git a/lldb/source/Plugins/Protocol/MCP/Tool.cpp b/lldb/source/Plugins/Protocol/MCP/Tool.cpp index 143470702a6fdc..2f451bf76e81de 100644 --- a/lldb/source/Plugins/Protocol/MCP/Tool.cpp +++ b/lldb/source/Plugins/Protocol/MCP/Tool.cpp @@ -7,9 +7,9 @@ //===----------------------------------------------------------------------===// #include "Tool.h" -#include "lldb/Core/Module.h" #include "lldb/Interpreter/CommandInterpreter.h" #include "lldb/Interpreter/CommandReturnObject.h" +#include "lldb/Protocol/MCP/Protocol.h" using namespace lldb_private; using namespace lldb_protocol; @@ -29,10 +29,10 @@ bool fromJSON(const llvm::json::Value &V, CommandToolArguments &A, O.mapOptional("arguments", A.arguments); } -/// Helper function to create a TextResult from a string output. -static lldb_protocol::mcp::TextResult createTextResult(std::string output, - bool is_error = false) { - lldb_protocol::mcp::TextResult text_result; +/// Helper function to create a CallToolResult from a string output. 
+static lldb_protocol::mcp::CallToolResult +createTextResult(std::string output, bool is_error = false) { + lldb_protocol::mcp::CallToolResult text_result; text_result.content.emplace_back( lldb_protocol::mcp::TextContent{{std::move(output)}}); text_result.isError = is_error; @@ -41,7 +41,7 @@ static lldb_protocol::mcp::TextResult createTextResult(std::string output, } // namespace -llvm::Expected<lldb_protocol::mcp::TextResult> +llvm::Expected<lldb_protocol::mcp::CallToolResult> CommandTool::Call(const lldb_protocol::mcp::ToolArguments &args) { if (!std::holds_alternative<json::Value>(args)) return createStringError("CommandTool requires arguments"); diff --git a/lldb/source/Plugins/Protocol/MCP/Tool.h b/lldb/source/Plugins/Protocol/MCP/Tool.h index b7b1756eb38d7f..1886525b9168f5 100644 --- a/lldb/source/Plugins/Protocol/MCP/Tool.h +++ b/lldb/source/Plugins/Protocol/MCP/Tool.h @@ -9,11 +9,11 @@ #ifndef LLDB_PLUGINS_PROTOCOL_MCP_TOOL_H #define LLDB_PLUGINS_PROTOCOL_MCP_TOOL_H -#include "lldb/Core/Debugger.h" #include "lldb/Protocol/MCP/Protocol.h" #include "lldb/Protocol/MCP/Tool.h" +#include "llvm/Support/Error.h" #include "llvm/Support/JSON.h" -#include <string> +#include <optional> namespace lldb_private::mcp { @@ -22,10 +22,10 @@ class CommandTool : public lldb_protocol::mcp::Tool { using lldb_protocol::mcp::Tool::Tool; ~CommandTool() = default; - virtual llvm::Expected<lldb_protocol::mcp::TextResult> + llvm::Expected<lldb_protocol::mcp::CallToolResult> Call(const lldb_protocol::mcp::ToolArguments &args) override; - virtual std::optional<llvm::json::Value> GetSchema() const override; + std::optional<llvm::json::Value> GetSchema() const override; }; } // namespace lldb_private::mcp diff --git a/lldb/source/Protocol/MCP/Protocol.cpp b/lldb/source/Protocol/MCP/Protocol.cpp index 65ddfaee70160b..0988f456adc263 100644 --- a/lldb/source/Protocol/MCP/Protocol.cpp +++ b/lldb/source/Protocol/MCP/Protocol.cpp @@ -167,32 +167,6 @@ bool operator==(const Notification &a, 
const Notification &b) { return a.method == b.method && a.params == b.params; } -llvm::json::Value toJSON(const ToolCapability &TC) { - return llvm::json::Object{{"listChanged", TC.listChanged}}; -} - -bool fromJSON(const llvm::json::Value &V, ToolCapability &TC, - llvm::json::Path P) { - llvm::json::ObjectMapper O(V, P); - return O && O.map("listChanged", TC.listChanged); -} - -llvm::json::Value toJSON(const ResourceCapability &RC) { - return llvm::json::Object{{"listChanged", RC.listChanged}, - {"subscribe", RC.subscribe}}; -} - -bool fromJSON(const llvm::json::Value &V, ResourceCapability &RC, - llvm::json::Path P) { - llvm::json::ObjectMapper O(V, P); - return O && O.map("listChanged", RC.listChanged) && - O.map("subscribe", RC.subscribe); -} - -llvm::json::Value toJSON(const Capabilities &C) { - return llvm::json::Object{{"tools", C.tools}, {"resources", C.resources}}; -} - bool fromJSON(const llvm::json::Value &V, Resource &R, llvm::json::Path P) { llvm::json::ObjectMapper O(V, P); return O && O.map("uri", R.uri) && O.map("name", R.name) && @@ -209,30 +183,25 @@ llvm::json::Value toJSON(const Resource &R) { return Result; } -bool fromJSON(const llvm::json::Value &V, Capabilities &C, llvm::json::Path P) { - llvm::json::ObjectMapper O(V, P); - return O && O.map("tools", C.tools); -} - -llvm::json::Value toJSON(const ResourceContents &RC) { +llvm::json::Value toJSON(const TextResourceContents &RC) { llvm::json::Object Result{{"uri", RC.uri}, {"text", RC.text}}; if (!RC.mimeType.empty()) Result.insert({"mimeType", RC.mimeType}); return Result; } -bool fromJSON(const llvm::json::Value &V, ResourceContents &RC, +bool fromJSON(const llvm::json::Value &V, TextResourceContents &RC, llvm::json::Path P) { llvm::json::ObjectMapper O(V, P); return O && O.map("uri", RC.uri) && O.map("text", RC.text) && O.mapOptional("mimeType", RC.mimeType); } -llvm::json::Value toJSON(const ResourceResult &RR) { +llvm::json::Value toJSON(const ReadResourceResult &RR) { return 
llvm::json::Object{{"contents", RR.contents}}; } -bool fromJSON(const llvm::json::Value &V, ResourceResult &RR, +bool fromJSON(const llvm::json::Value &V, ReadResourceResult &RR, llvm::json::Path P) { llvm::json::ObjectMapper O(V, P); return O && O.map("contents", RR.contents); @@ -247,15 +216,6 @@ bool fromJSON(const llvm::json::Value &V, TextContent &TC, llvm::json::Path P) { return O && O.map("text", TC.text); } -llvm::json::Value toJSON(const TextResult &TR) { - return llvm::json::Object{{"content", TR.content}, {"isError", TR.isError}}; -} - -bool fromJSON(const llvm::json::Value &V, TextResult &TR, llvm::json::Path P) { - llvm::json::ObjectMapper O(V, P); - return O && O.map("content", TR.content) && O.map("isError", TR.isError); -} - llvm::json::Value toJSON(const ToolDefinition &TD) { llvm::json::Object Result{{"name", TD.name}}; if (!TD.description.empty()) @@ -325,4 +285,159 @@ bool fromJSON(const llvm::json::Value &V, Message &M, llvm::json::Path P) { return false; } +json::Value toJSON(const Implementation &I) { + json::Object result{{"name", I.name}, {"version", I.version}}; + + if (!I.title.empty()) + result.insert({"title", I.title}); + + return result; +} + +bool fromJSON(const json::Value &V, Implementation &I, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("name", I.name) && O.mapOptional("title", I.title) && + O.mapOptional("version", I.version); +} + +json::Value toJSON(const ClientCapabilities &C) { return json::Object{}; } + +bool fromJSON(const json::Value &, ClientCapabilities &, json::Path) { + return true; +} + +json::Value toJSON(const ServerCapabilities &C) { + json::Object result{}; + + if (C.supportsToolsList) + result.insert({"tools", json::Object{{"listChanged", true}}}); + + if (C.supportsResourcesList || C.supportsResourcesSubscribe) { + json::Object resources; + if (C.supportsResourcesList) + resources.insert({"listChanged", true}); + if (C.supportsResourcesSubscribe) + resources.insert({"subscribe", true}); + 
result.insert({"resources", std::move(resources)}); + } + + if (C.supportsCompletions) + result.insert({"completions", json::Object{}}); + + if (C.supportsLogging) + result.insert({"logging", json::Object{}}); + + return result; +} + +bool fromJSON(const json::Value &V, ServerCapabilities &C, json::Path P) { + const json::Object *O = V.getAsObject(); + if (!O) { + P.report("expected object"); + return false; + } + + if (O->find("tools") != O->end()) + C.supportsToolsList = true; + + return true; +} + +json::Value toJSON(const InitializeParams &P) { + return json::Object{ + {"protocolVersion", P.protocolVersion}, + {"capabilities", P.capabilities}, + {"clientInfo", P.clientInfo}, + }; +} + +bool fromJSON(const json::Value &V, InitializeParams &I, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("protocolVersion", I.protocolVersion) && + O.map("capabilities", I.capabilities) && + O.map("clientInfo", I.clientInfo); +} + +json::Value toJSON(const InitializeResult &R) { + json::Object result{{"protocolVersion", R.protocolVersion}, + {"capabilities", R.capabilities}, + {"serverInfo", R.serverInfo}}; + + if (!R.instructions.empty()) + result.insert({"instructions", R.instructions}); + + return result; +} + +bool fromJSON(const json::Value &V, InitializeResult &R, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("protocolVersion", R.protocolVersion) && + O.map("capabilities", R.capabilities) && + O.map("serverInfo", R.serverInfo) && + O.mapOptional("instructions", R.instructions); +} + +json::Value toJSON(const ListToolsResult &R) { + return json::Object{{"tools", R.tools}}; +} + +bool fromJSON(const json::Value &V, ListToolsResult &R, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("tools", R.tools); +} + +json::Value toJSON(const CallToolResult &R) { + json::Object result{{"content", R.content}}; + + if (R.isError) + result.insert({"isError", R.isError}); + if (R.structuredContent) + result.insert({"structuredContent", 
*R.structuredContent}); + + return result; +} + +bool fromJSON(const json::Value &V, CallToolResult &R, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("content", R.content) && + O.mapOptional("isError", R.isError) && + mapRaw(V, "structuredContent", R.structuredContent, P); +} + +json::Value toJSON(const CallToolParams &R) { + json::Object result{{"name", R.name}}; + + if (R.arguments) + result.insert({"arguments", *R.arguments}); + + return result; +} + +bool fromJSON(const json::Value &V, CallToolParams &R, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("name", R.name) && mapRaw(V, "arguments", R.arguments, P); +} + +json::Value toJSON(const ReadResourceParams &R) { + return json::Object{{"uri", R.uri}}; +} + +bool fromJSON(const json::Value &V, ReadResourceParams &R, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("uri", R.uri); +} + +json::Value toJSON(const ListResourcesResult &R) { + return json::Object{{"resources", R.resources}}; +} + +bool fromJSON(const json::Value &V, ListResourcesResult &R, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("resources", R.resources); +} + +json::Value toJSON(const Void &R) { return json::Object{}; } + +bool fromJSON(const json::Value &V, Void &R, json::Path P) { return true; } + } // namespace lldb_protocol::mcp diff --git a/lldb/source/Protocol/MCP/Server.cpp b/lldb/source/Protocol/MCP/Server.cpp index c1a6026b11090a..63c2d01d17922a 100644 --- a/lldb/source/Protocol/MCP/Server.cpp +++ b/lldb/source/Protocol/MCP/Server.cpp @@ -8,6 +8,8 @@ #include "lldb/Protocol/MCP/Server.h" #include "lldb/Protocol/MCP/MCPError.h" +#include "lldb/Protocol/MCP/Protocol.h" +#include "llvm/Support/JSON.h" using namespace lldb_protocol::mcp; using namespace llvm; @@ -79,22 +81,23 @@ void Server::AddNotificationHandler(llvm::StringRef method, llvm::Expected<Response> Server::InitializeHandler(const Request &request) { Response response; - response.result = 
llvm::json::Object{ - {"protocolVersion", mcp::kProtocolVersion}, - {"capabilities", GetCapabilities()}, - {"serverInfo", - llvm::json::Object{{"name", m_name}, {"version", m_version}}}}; + InitializeResult result; + result.protocolVersion = mcp::kProtocolVersion; + result.capabilities = GetCapabilities(); + result.serverInfo.name = m_name; + result.serverInfo.version = m_version; + response.result = std::move(result); return response; } llvm::Expected<Response> Server::ToolsListHandler(const Request &request) { Response response; - llvm::json::Array tools; + ListToolsResult result; for (const auto &tool : m_tools) - tools.emplace_back(toJSON(tool.second->GetDefinition())); + result.tools.emplace_back(tool.second->GetDefinition()); - response.result = llvm::json::Object{{"tools", std::move(tools)}}; + response.result = std::move(result); return response; } @@ -104,16 +107,12 @@ llvm::Expected<Response> Server::ToolsCallHandler(const Request &request) { if (!request.params) return llvm::createStringError("no tool parameters"); + CallToolParams params; + json::Path::Root root("params"); + if (!fromJSON(request.params, params, root)) + return root.getError(); - const json::Object *param_obj = request.params->getAsObject(); - if (!param_obj) - return llvm::createStringError("no tool parameters"); - - const json::Value *name = param_obj->get("name"); - if (!name) - return llvm::createStringError("no tool name"); - - llvm::StringRef tool_name = name->getAsString().value_or(""); + llvm::StringRef tool_name = params.name; if (tool_name.empty()) return llvm::createStringError("no tool name"); @@ -122,10 +121,10 @@ llvm::Expected<Response> Server::ToolsCallHandler(const Request &request) { return llvm::createStringError(llvm::formatv("no tool \"{0}\"", tool_name)); ToolArguments tool_args; - if (const json::Value *args = param_obj->get("arguments")) - tool_args = *args; + if (params.arguments) + tool_args = *params.arguments; - llvm::Expected<TextResult> text_result = 
it->second->Call(tool_args); + llvm::Expected<CallToolResult> text_result = it->second->Call(tool_args); if (!text_result) return text_result.takeError(); @@ -137,14 +136,13 @@ llvm::Expected<Response> Server::ToolsCallHandler(const Request &request) { llvm::Expected<Response> Server::ResourcesListHandler(const Request &request) { Response response; - llvm::json::Array resources; - + ListResourcesResult result; for (std::unique_ptr<ResourceProvider> &resource_provider_up : - m_resource_providers) { + m_resource_providers) for (const Resource &resource : resource_provider_up->GetResources()) - resources.push_back(resource); - } - response.result = llvm::json::Object{{"resources", std::move(resources)}}; + result.resources.push_back(resource); + + response.result = std::move(result); return response; } @@ -155,21 +153,18 @@ llvm::Expected<Response> Server::ResourcesReadHandler(const Request &request) { if (!request.params) return llvm::createStringError("no resource parameters"); - const json::Object *param_obj = request.params->getAsObject(); - if (!param_obj) - return llvm::createStringError("no resource parameters"); - - const json::Value *uri = param_obj->get("uri"); - if (!uri) - return llvm::createStringError("no resource uri"); + ReadResourceParams params; + json::Path::Root root("params"); + if (!fromJSON(request.params, params, root)) + return root.getError(); - llvm::StringRef uri_str = uri->getAsString().value_or(""); + llvm::StringRef uri_str = params.uri; if (uri_str.empty()) return llvm::createStringError("no resource uri"); for (std::unique_ptr<ResourceProvider> &resource_provider_up : m_resource_providers) { - llvm::Expected<ResourceResult> result = + llvm::Expected<ReadResourceResult> result = resource_provider_up->ReadResource(uri_str); if (result.errorIsA<UnsupportedURI>()) { llvm::consumeError(result.takeError()); @@ -188,12 +183,12 @@ llvm::Expected<Response> Server::ResourcesReadHandler(const Request &request) { MCPError::kResourceNotFound); } 
-Capabilities Server::GetCapabilities() { - lldb_protocol::mcp::Capabilities capabilities; - capabilities.tools.listChanged = true; +ServerCapabilities Server::GetCapabilities() { + lldb_protocol::mcp::ServerCapabilities capabilities; + capabilities.supportsToolsList = true; // FIXME: Support sending notifications when a debugger/target are // added/removed. - capabilities.resources.listChanged = false; + capabilities.supportsResourcesList = false; return capabilities; } diff --git a/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp b/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp index 393748cdc65591..9fa446133d46f7 100644 --- a/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp +++ b/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#include "ProtocolMCPTestUtilities.h" #include "TestingSupport/Host/JSONTransportTestUtilities.h" #include "TestingSupport/Host/PipeTestUtilities.h" #include "TestingSupport/SubsystemRAII.h" @@ -20,6 +21,7 @@ #include "lldb/Protocol/MCP/Resource.h" #include "lldb/Protocol/MCP/Server.h" #include "lldb/Protocol/MCP/Tool.h" +#include "llvm/ADT/StringRef.h" #include "llvm/Support/Error.h" #include "llvm/Support/JSON.h" #include "llvm/Testing/Support/Error.h" @@ -27,13 +29,11 @@ #include "gtest/gtest.h" #include <chrono> #include <condition_variable> -#include <mutex> using namespace llvm; using namespace lldb; using namespace lldb_private; using namespace lldb_protocol::mcp; -using testing::_; namespace { class TestMCPTransport final : public MCPTransport { @@ -60,7 +60,7 @@ class TestTool : public Tool { public: using Tool::Tool; - llvm::Expected<TextResult> Call(const ToolArguments &args) override { + llvm::Expected<CallToolResult> Call(const ToolArguments &args) override { std::string argument; if (const json::Object *args_obj = std::get<json::Value>(args).getAsObject()) { @@ -69,7 +69,7 @@ class TestTool : public Tool { } } - 
TextResult text_result; + CallToolResult text_result; text_result.content.emplace_back(TextContent{{argument}}); return text_result; } @@ -91,17 +91,17 @@ class TestResourceProvider : public ResourceProvider { return resources; } - llvm::Expected<ResourceResult> + llvm::Expected<ReadResourceResult> ReadResource(llvm::StringRef uri) const override { if (uri != "lldb://foo/bar") return llvm::make_error<UnsupportedURI>(uri.str()); - ResourceContents contents; + TextResourceContents contents; contents.uri = "lldb://foo/bar"; contents.mimeType = "application/json"; contents.text = "foobar"; - ResourceResult result; + ReadResourceResult result; result.contents.push_back(contents); return result; } @@ -112,7 +112,7 @@ class ErrorTool : public Tool { public: using Tool::Tool; - llvm::Expected<TextResult> Call(const ToolArguments &args) override { + llvm::Expected<CallToolResult> Call(const ToolArguments &args) override { return llvm::createStringError("error"); } }; @@ -122,8 +122,8 @@ class FailTool : public Tool { public: using Tool::Tool; - llvm::Expected<TextResult> Call(const ToolArguments &args) override { - TextResult text_result; + llvm::Expected<CallToolResult> Call(const ToolArguments &args) override { + CallToolResult text_result; text_result.content.emplace_back(TextContent{{"failed"}}); text_result.isError = true; return text_result; @@ -146,6 +146,8 @@ class ProtocolServerMCPTest : public PipePairTest { return transport_up->Write(*value); } + llvm::Error Write(json::Value value) { return transport_up->Write(value); } + /// Run the transport MainLoop and return any messages received. 
llvm::Error Run(std::chrono::milliseconds timeout = std::chrono::milliseconds(200)) { @@ -182,37 +184,43 @@ class ProtocolServerMCPTest : public PipePairTest { } }; +template <typename T> +Request make_request(StringLiteral method, T &&params, Id id = 1) { + return Request{id, method.str(), toJSON(std::forward<T>(params))}; +} + +template <typename T> Response make_response(T &&result, Id id = 1) { + return Response{id, std::forward<T>(result)}; +} + } // namespace TEST_F(ProtocolServerMCPTest, Initialization) { - llvm::StringLiteral request = - R"json({"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"lldb-unit","version":"0.1.0"}},"jsonrpc":"2.0","id":1})json"; - llvm::StringLiteral response = - R"json({"id":1,"jsonrpc":"2.0","result":{"capabilities":{"resources":{"listChanged":false,"subscribe":false},"tools":{"listChanged":true}},"protocolVersion":"2024-11-05","serverInfo":{"name":"lldb-mcp","version":"0.1.0"}}})json"; + Request request = make_request( + "initialize", InitializeParams{/*protocolVersion=*/"2024-11-05", + /*capabilities=*/{}, + /*clientInfo=*/{"lldb-unit", "0.1.0"}}); + Response response = make_response( + InitializeResult{/*protocolVersion=*/"2024-11-05", + /*capabilities=*/{/*supportsToolsList=*/true}, + /*serverInfo=*/{"lldb-mcp", "0.1.0"}}); ASSERT_THAT_ERROR(Write(request), Succeeded()); - llvm::Expected<Response> expected_resp = json::parse<Response>(response); - ASSERT_THAT_EXPECTED(expected_resp, llvm::Succeeded()); - EXPECT_CALL(message_handler, Received(*expected_resp)); + EXPECT_CALL(message_handler, Received(response)); EXPECT_THAT_ERROR(Run(), Succeeded()); } TEST_F(ProtocolServerMCPTest, ToolsList) { server_up->AddTool(std::make_unique<TestTool>("test", "test tool")); - llvm::StringLiteral request = - R"json({"method":"tools/list","params":{},"jsonrpc":"2.0","id":"one"})json"; + Request request = make_request("tools/list", Void{}, /*id=*/"one"); ToolDefinition test_tool;
test_tool.name = "test"; test_tool.description = "test tool"; test_tool.inputSchema = json::Object{{"type", "object"}}; - Response response; - response.id = "one"; - response.result = json::Object{ - {"tools", json::Array{std::move(test_tool)}}, - }; + Response response = make_response(ListToolsResult{{test_tool}}, /*id=*/"one"); ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); EXPECT_CALL(message_handler, Received(response)); @@ -222,60 +230,61 @@ TEST_F(ProtocolServerMCPTest, ToolsList) { TEST_F(ProtocolServerMCPTest, ResourcesList) { server_up->AddResourceProvider(std::make_unique<TestResourceProvider>()); - llvm::StringLiteral request = - R"json({"method":"resources/list","params":{},"jsonrpc":"2.0","id":2})json"; - llvm::StringLiteral response = - R"json({"id":2,"jsonrpc":"2.0","result":{"resources":[{"description":"description","mimeType":"application/json","name":"name","uri":"lldb://foo/bar"}]}})json"; + Request request = make_request("resources/list", Void{}); + Response response = make_response(ListResourcesResult{ + {{/*uri=*/"lldb://foo/bar", /*name=*/"name", + /*description=*/"description", /*mimeType=*/"application/json"}}}); ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - llvm::Expected<Response> expected_resp = json::parse<Response>(response); - ASSERT_THAT_EXPECTED(expected_resp, llvm::Succeeded()); - EXPECT_CALL(message_handler, Received(*expected_resp)); + EXPECT_CALL(message_handler, Received(response)); EXPECT_THAT_ERROR(Run(), Succeeded()); } TEST_F(ProtocolServerMCPTest, ToolsCall) { server_up->AddTool(std::make_unique<TestTool>("test", "test tool")); - llvm::StringLiteral request = - R"json({"method":"tools/call","params":{"name":"test","arguments":{"arguments":"foo","debugger_id":0}},"jsonrpc":"2.0","id":11})json"; - llvm::StringLiteral response = - R"json({"id":11,"jsonrpc":"2.0","result":{"content":[{"text":"foo","type":"text"}],"isError":false}})json"; + Request request = make_request( + "tools/call", 
CallToolParams{/*name=*/"test", /*arguments=*/json::Object{ + {"arguments", "foo"}, + {"debugger_id", 0}, + }}); + Response response = make_response(CallToolResult{{{/*text=*/"foo"}}}); ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - llvm::Expected<Response> expected_resp = json::parse<Response>(response); - ASSERT_THAT_EXPECTED(expected_resp, llvm::Succeeded()); - EXPECT_CALL(message_handler, Received(*expected_resp)); + EXPECT_CALL(message_handler, Received(response)); EXPECT_THAT_ERROR(Run(), Succeeded()); } TEST_F(ProtocolServerMCPTest, ToolsCallError) { server_up->AddTool(std::make_unique<ErrorTool>("error", "error tool")); - llvm::StringLiteral request = - R"json({"method":"tools/call","params":{"name":"error","arguments":{"arguments":"foo","debugger_id":0}},"jsonrpc":"2.0","id":11})json"; - llvm::StringLiteral response = - R"json({"error":{"code":-32603,"message":"error"},"id":11,"jsonrpc":"2.0"})json"; + Request request = make_request( + "tools/call", CallToolParams{/*name=*/"error", /*arguments=*/json::Object{ + {"arguments", "foo"}, + {"debugger_id", 0}, + }}); + Response response = + make_response(lldb_protocol::mcp::Error{eErrorCodeInternalError, + /*message=*/"error"}); ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - llvm::Expected<Response> expected_resp = json::parse<Response>(response); - ASSERT_THAT_EXPECTED(expected_resp, llvm::Succeeded()); - EXPECT_CALL(message_handler, Received(*expected_resp)); + EXPECT_CALL(message_handler, Received(response)); EXPECT_THAT_ERROR(Run(), Succeeded()); } TEST_F(ProtocolServerMCPTest, ToolsCallFail) { server_up->AddTool(std::make_unique<FailTool>("fail", "fail tool")); - llvm::StringLiteral request = - R"json({"method":"tools/call","params":{"name":"fail","arguments":{"arguments":"foo","debugger_id":0}},"jsonrpc":"2.0","id":11})json"; - llvm::StringLiteral response = - R"json({"id":11,"jsonrpc":"2.0","result":{"content":[{"text":"failed","type":"text"}],"isError":true}})json"; + Request request = 
make_request( + "tools/call", CallToolParams{/*name=*/"fail", /*arguments=*/json::Object{ + {"arguments", "foo"}, + {"debugger_id", 0}, + }}); + Response response = + make_response(CallToolResult{{{/*text=*/"failed"}}, /*isError=*/true}); ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - llvm::Expected<Response> expected_resp = json::parse<Response>(response); - ASSERT_THAT_EXPECTED(expected_resp, llvm::Succeeded()); - EXPECT_CALL(message_handler, Received(*expected_resp)); + EXPECT_CALL(message_handler, Received(response)); EXPECT_THAT_ERROR(Run(), Succeeded()); } diff --git a/lldb/unittests/Protocol/ProtocolMCPTest.cpp b/lldb/unittests/Protocol/ProtocolMCPTest.cpp index ea19922522ffe0..396e361e873fe3 100644 --- a/lldb/unittests/Protocol/ProtocolMCPTest.cpp +++ b/lldb/unittests/Protocol/ProtocolMCPTest.cpp @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#include "ProtocolMCPTestUtilities.h" #include "TestingSupport/TestUtilities.h" #include "lldb/Protocol/MCP/Protocol.h" #include "llvm/Testing/Support/Error.h" @@ -54,31 +55,16 @@ TEST(ProtocolMCPTest, Notification) { EXPECT_EQ(notification.params, deserialized_notification->params); } -TEST(ProtocolMCPTest, ToolCapability) { - ToolCapability tool_capability; - tool_capability.listChanged = true; +TEST(ProtocolMCPTest, ServerCapabilities) { + ServerCapabilities capabilities; + capabilities.supportsToolsList = true; - llvm::Expected<ToolCapability> deserialized_tool_capability = - roundtripJSON(tool_capability); - ASSERT_THAT_EXPECTED(deserialized_tool_capability, llvm::Succeeded()); - - EXPECT_EQ(tool_capability.listChanged, - deserialized_tool_capability->listChanged); -} - -TEST(ProtocolMCPTest, Capabilities) { - ToolCapability tool_capability; - tool_capability.listChanged = true; - - Capabilities capabilities; - capabilities.tools = tool_capability; - - llvm::Expected<Capabilities> deserialized_capabilities = + llvm::Expected<ServerCapabilities> 
deserialized_capabilities = roundtripJSON(capabilities); ASSERT_THAT_EXPECTED(deserialized_capabilities, llvm::Succeeded()); - EXPECT_EQ(capabilities.tools.listChanged, - deserialized_capabilities->tools.listChanged); + EXPECT_EQ(capabilities.supportsToolsList, + deserialized_capabilities->supportsToolsList); } TEST(ProtocolMCPTest, TextContent) { @@ -92,18 +78,18 @@ TEST(ProtocolMCPTest, TextContent) { EXPECT_EQ(text_content.text, deserialized_text_content->text); } -TEST(ProtocolMCPTest, TextResult) { +TEST(ProtocolMCPTest, CallToolResult) { TextContent text_content1; text_content1.text = "Text 1"; TextContent text_content2; text_content2.text = "Text 2"; - TextResult text_result; + CallToolResult text_result; text_result.content = {text_content1, text_content2}; text_result.isError = true; - llvm::Expected<TextResult> deserialized_text_result = + llvm::Expected<CallToolResult> deserialized_text_result = roundtripJSON(text_result); ASSERT_THAT_EXPECTED(deserialized_text_result, llvm::Succeeded()); @@ -237,13 +223,13 @@ TEST(ProtocolMCPTest, ResourceWithoutOptionals) { EXPECT_TRUE(deserialized_resource->mimeType.empty()); } -TEST(ProtocolMCPTest, ResourceContents) { - ResourceContents contents; +TEST(ProtocolMCPTest, TextResourceContents) { + TextResourceContents contents; contents.uri = "resource://example/content"; contents.text = "This is the content of the resource"; contents.mimeType = "text/plain"; - llvm::Expected<ResourceContents> deserialized_contents = + llvm::Expected<TextResourceContents> deserialized_contents = roundtripJSON(contents); ASSERT_THAT_EXPECTED(deserialized_contents, llvm::Succeeded()); @@ -252,12 +238,12 @@ TEST(ProtocolMCPTest, ResourceContents) { EXPECT_EQ(contents.mimeType, deserialized_contents->mimeType); } -TEST(ProtocolMCPTest, ResourceContentsWithoutMimeType) { - ResourceContents contents; +TEST(ProtocolMCPTest, TextResourceContentsWithoutMimeType) { + TextResourceContents contents; contents.uri = 
"resource://example/content-no-mime"; contents.text = "Content without mime type specified"; - llvm::Expected<ResourceContents> deserialized_contents = + llvm::Expected<TextResourceContents> deserialized_contents = roundtripJSON(contents); ASSERT_THAT_EXPECTED(deserialized_contents, llvm::Succeeded()); @@ -266,21 +252,22 @@ TEST(ProtocolMCPTest, ResourceContentsWithoutMimeType) { EXPECT_TRUE(deserialized_contents->mimeType.empty()); } -TEST(ProtocolMCPTest, ResourceResult) { - ResourceContents contents1; +TEST(ProtocolMCPTest, ReadResourceResult) { + TextResourceContents contents1; contents1.uri = "resource://example/content1"; contents1.text = "First resource content"; contents1.mimeType = "text/plain"; - ResourceContents contents2; + TextResourceContents contents2; contents2.uri = "resource://example/content2"; contents2.text = "Second resource content"; contents2.mimeType = "application/json"; - ResourceResult result; + ReadResourceResult result; result.contents = {contents1, contents2}; - llvm::Expected<ResourceResult> deserialized_result = roundtripJSON(result); + llvm::Expected<ReadResourceResult> deserialized_result = + roundtripJSON(result); ASSERT_THAT_EXPECTED(deserialized_result, llvm::Succeeded()); ASSERT_EQ(result.contents.size(), deserialized_result->contents.size()); @@ -296,10 +283,11 @@ TEST(ProtocolMCPTest, ResourceResult) { deserialized_result->contents[1].mimeType); } -TEST(ProtocolMCPTest, ResourceResultEmpty) { - ResourceResult result; +TEST(ProtocolMCPTest, ReadResourceResultEmpty) { + ReadResourceResult result; - llvm::Expected<ResourceResult> deserialized_result = roundtripJSON(result); + llvm::Expected<ReadResourceResult> deserialized_result = + roundtripJSON(result); ASSERT_THAT_EXPECTED(deserialized_result, llvm::Succeeded()); EXPECT_TRUE(deserialized_result->contents.empty()); diff --git a/lldb/unittests/Protocol/ProtocolMCPTestUtilities.h b/lldb/unittests/Protocol/ProtocolMCPTestUtilities.h new file mode 100644 index 
00000000000000..f8a14f4be03c91 --- /dev/null +++ b/lldb/unittests/Protocol/ProtocolMCPTestUtilities.h @@ -0,0 +1,40 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLDB_UNITTESTS_PROTOCOL_PROTOCOLMCPTESTUTILITIES_H +#define LLDB_UNITTESTS_PROTOCOL_PROTOCOLMCPTESTUTILITIES_H + +#include "lldb/Protocol/MCP/Protocol.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/JSON.h" // IWYU pragma: keep +#include "gtest/gtest.h" // IWYU pragma: keep +#include <ostream> +#include <variant> + +namespace lldb_protocol::mcp { + +inline void PrintTo(const Request &req, std::ostream *os) { + *os << llvm::formatv("{0}", toJSON(req)).str(); +} + +inline void PrintTo(const Response &resp, std::ostream *os) { + *os << llvm::formatv("{0}", toJSON(resp)).str(); +} + +inline void PrintTo(const Notification &note, std::ostream *os) { + *os << llvm::formatv("{0}", toJSON(note)).str(); +} + +inline void PrintTo(const Message &message, std::ostream *os) { + return std::visit([os](auto &&message) { return PrintTo(message, os); }, + message); +} + +} // namespace lldb_protocol::mcp + +#endif _______________________________________________ lldb-commits mailing list lldb-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/lldb-commits