common/Session.cpp | 50 +++++++++++------- common/Session.hpp | 38 +++++++++++-- kit/ChildSession.cpp | 14 ++--- kit/ChildSession.hpp | 22 +++++--- kit/Kit.cpp | 12 ++-- net/Socket.cpp | 8 +- net/Socket.hpp | 109 ++++++++++++++++++++++++++++++++++------ net/SslSocket.hpp | 2 net/WebSocketHandler.hpp | 63 +++++++++++++++++++---- test/UnitWOPIVersionRestore.cpp | 1 tools/WebSocketDump.cpp | 2 wsd/ClientSession.cpp | 44 +++++++--------- wsd/ClientSession.hpp | 21 ++++--- wsd/DocumentBroker.cpp | 20 ++++--- wsd/DocumentBroker.hpp | 11 ++-- wsd/LOOLWSD.cpp | 53 +++++++++++-------- wsd/TestStubs.cpp | 8 +- 17 files changed, 332 insertions(+), 146 deletions(-)
New commits: commit e924625cc1af8736505f363fc525d20a6373bb95 Author: Michael Meeks <michael.me...@collabora.com> AuthorDate: Fri Mar 6 17:43:46 2020 +0000 Commit: Michael Meeks <michael.me...@collabora.com> CommitDate: Wed Mar 11 16:48:03 2020 +0100 re-factor: Socket / WebSocketHandler. Essentially we want to be able to separate low-level socket code (e.g. TCP vs. UDS) from protocol handling (e.g. WebSocketHandler) and from the client sessions themselves, which handle and send messages and now implement the simple MessageHandlerInterface. Some helpful renaming too: s/SocketHandlerInterface/ProtocolHandlerInterface/ Change-Id: I58092b5e0b5792fda47498fb2c875851eada461d Reviewed-on: https://gerrit.libreoffice.org/c/online/+/90138 Tested-by: Jenkins CollaboraOffice <jenkinscollaboraoff...@gmail.com> Reviewed-by: Michael Meeks <michael.me...@collabora.com> diff --git a/common/Session.cpp b/common/Session.cpp index 4b4c563d6..15dbe86d7 100644 --- a/common/Session.cpp +++ b/common/Session.cpp @@ -44,7 +44,9 @@ using namespace LOOLProtocol; using Poco::Exception; using std::size_t; -Session::Session(const std::string& name, const std::string& id, bool readOnly) : +Session::Session(const std::shared_ptr<ProtocolHandlerInterface> &protocol, + const std::string& name, const std::string& id, bool readOnly) : + MessageHandlerInterface(protocol), _id(id), _name(name), _disconnected(false), @@ -65,14 +67,26 @@ Session::~Session() bool Session::sendTextFrame(const char* buffer, const int length) { + if (!_protocol) + { + LOG_TRC("ERR - missing protocol " << getName() << ": Send: [" << getAbbreviatedMessage(buffer, length) << "]."); + return false; + } + LOG_TRC(getName() << ": Send: [" << getAbbreviatedMessage(buffer, length) << "]."); - return sendMessage(buffer, length, WSOpCode::Text) >= length; + return _protocol->sendTextMessage(buffer, length) >= length; } bool Session::sendBinaryFrame(const char *buffer, int length) { + if (!_protocol) + { + LOG_TRC("ERR - missing protocol " << 
getName() << ": Send: " << std::to_string(length) << " binary bytes."); + return false; + } + LOG_TRC(getName() << ": Send: " << std::to_string(length) << " binary bytes."); - return sendMessage(buffer, length, WSOpCode::Binary) >= length; + return _protocol->sendBinaryMessage(buffer, length) >= length; } void Session::parseDocOptions(const StringVector& tokens, int& part, std::string& timestamp, std::string& doctemplate) @@ -196,15 +210,20 @@ void Session::disconnect() } } -void Session::shutdown(const WebSocketHandler::StatusCodes statusCode, const std::string& statusMessage) +void Session::shutdown(bool goingAway, const std::string& statusMessage) { - LOG_TRC("Shutting down WS [" << getName() << "] with statusCode [" << - static_cast<unsigned>(statusCode) << "] and reason [" << statusMessage << "]."); + LOG_TRC("Shutting down WS [" << getName() << "] " << + (goingAway ? "going" : "normal") << + " and reason [" << statusMessage << "]."); // See protocol.txt for this application-level close frame. - sendMessage("close: " + statusMessage); - - WebSocketHandler::shutdown(statusCode, statusMessage); + if (_protocol) + { + // skip the queue; FIXME: should we flush SessionClient's queue ? 
+ std::string closeMsg = "close: " + statusMessage; + _protocol->sendTextMessage(closeMsg, closeMsg.size()); + _protocol->shutdown(goingAway, statusMessage); + } } void Session::handleMessage(const std::vector<char> &data) @@ -238,21 +257,12 @@ void Session::handleMessage(const std::vector<char> &data) void Session::getIOStats(uint64_t &sent, uint64_t &recv) { - std::shared_ptr<StreamSocket> socket = getSocket().lock(); - if (socket) - socket->getIOStats(sent, recv); - else - { - sent = 0; - recv = 0; - } + _protocol->getIOStats(sent, recv); } void Session::dumpState(std::ostream& os) { - WebSocketHandler::dumpState(os); - - os << "\t\tid: " << _id + os << "\t\tid: " << _id << "\n\t\tname: " << _name << "\n\t\tdisconnected: " << _disconnected << "\n\t\tisActive: " << _isActive diff --git a/common/Session.hpp b/common/Session.hpp index 6b5e93322..dbf75ad2f 100644 --- a/common/Session.hpp +++ b/common/Session.hpp @@ -64,7 +64,7 @@ public: }; /// Base class of a WebSocket session. -class Session : public WebSocketHandler +class Session : public MessageHandlerInterface { public: const std::string& getId() const { return _id; } @@ -74,8 +74,32 @@ public: virtual void setReadOnly() { _isReadOnly = true; } bool isReadOnly() const { return _isReadOnly; } + /// overridden to prepend client ids on messages by the Kit virtual bool sendBinaryFrame(const char* buffer, int length); virtual bool sendTextFrame(const char* buffer, const int length); + + /// Get notified that the underlying transports disconnected + void onDisconnect() override { /* ignore */ } + + bool hasQueuedMessages() const override + { + // queued in Socket output buffer + return false; + } + + // By default rely on the socket buffer. + void writeQueuedMessages() override + { + assert(false); + } + + /// Sends a WebSocket Text message. + int sendMessage(const std::string& msg) + { + return sendTextFrame(msg.data(), msg.size()); + } + + // FIXME: remove synonym - and clean from WebSocketHandler too ... (?) 
bool sendTextFrame(const std::string& text) { return sendTextFrame(text.data(), text.size()); @@ -98,12 +122,10 @@ public: virtual void disconnect(); /// clean & normal shutdown - void shutdownNormal(const std::string& statusMessage = "") - { shutdown(WebSocketHandler::StatusCodes::NORMAL_CLOSE, statusMessage); } + void shutdownNormal(const std::string& statusMessage = "") { shutdown(false, statusMessage); } /// abnormal / hash shutdown end-point going away - void shutdownGoingAway(const std::string& statusMessage = "") - { shutdown(WebSocketHandler::StatusCodes::ENDPOINT_GOING_AWAY, statusMessage); } + void shutdownGoingAway(const std::string& statusMessage = "") { shutdown(true, statusMessage); } bool isActive() const { return _isActive; } void setIsActive(bool active) { _isActive = active; } @@ -165,7 +187,8 @@ public: } protected: - Session(const std::string& name, const std::string& id, bool readonly); + Session(const std::shared_ptr<ProtocolHandlerInterface> &handler, + const std::string& name, const std::string& id, bool readonly); virtual ~Session(); /// Parses the options of the "load" command, @@ -181,8 +204,7 @@ protected: private: - void shutdown(const WebSocketHandler::StatusCodes statusCode = WebSocketHandler::StatusCodes::NORMAL_CLOSE, - const std::string& statusMessage = ""); + void shutdown(bool goingAway = false, const std::string& statusMessage = ""); virtual bool _handleInput(const char* buffer, int length) = 0; diff --git a/kit/ChildSession.cpp b/kit/ChildSession.cpp index 682012a50..4842b8ffe 100644 --- a/kit/ChildSession.cpp +++ b/kit/ChildSession.cpp @@ -19,7 +19,6 @@ #include <Poco/JSON/Object.h> #include <Poco/JSON/Parser.h> -#include <Poco/Net/WebSocket.h> #include <Poco/StreamCopier.h> #include <Poco/URI.h> #include <Poco/BinaryReader.h> @@ -62,10 +61,12 @@ std::vector<unsigned char> decodeBase64(const std::string & inputBase64) } -ChildSession::ChildSession(const std::string& id, - const std::string& jailId, - DocumentManagerInterface& 
docManager) : - Session("ToMaster-" + id, id, false), +ChildSession::ChildSession( + const std::shared_ptr<ProtocolHandlerInterface> &protocol, + const std::string& id, + const std::string& jailId, + DocumentManagerInterface& docManager) : + Session(protocol, "ToMaster-" + id, id, false), _jailId(jailId), _docManager(&docManager), _viewId(-1), @@ -98,7 +99,8 @@ void ChildSession::disconnect() LOG_WRN("Skipping unload on incomplete view."); } - Session::disconnect(); +// This shuts down the shared socket, which is not what we want. +// Session::disconnect(); } } diff --git a/kit/ChildSession.hpp b/kit/ChildSession.hpp index 9bb2b7d0f..c7a248546 100644 --- a/kit/ChildSession.hpp +++ b/kit/ChildSession.hpp @@ -199,9 +199,11 @@ public: /// a new view) or nullptr (when first view). /// jailId The JailID of the jail root directory, // used by downloadas to construct jailed path. - ChildSession(const std::string& id, - const std::string& jailId, - DocumentManagerInterface& docManager); + ChildSession( + const std::shared_ptr<ProtocolHandlerInterface> &protocol, + const std::string& id, + const std::string& jailId, + DocumentManagerInterface& docManager); virtual ~ChildSession(); bool getStatus(const char* buffer, int length); @@ -219,12 +221,22 @@ public: bool sendTextFrame(const char* buffer, int length) override { + if (!_docManager) + { + LOG_TRC("ERR dropping - client-" + getId() + ' ' + std::string(buffer, length)); + return false; + } const auto msg = "client-" + getId() + ' ' + std::string(buffer, length); return _docManager->sendFrame(msg.data(), msg.size(), WSOpCode::Text); } bool sendBinaryFrame(const char* buffer, int length) override { + if (!_docManager) + { + LOG_TRC("ERR dropping binary - client-" + getId()); + return false; + } const auto msg = "client-" + getId() + ' ' + std::string(buffer, length); return _docManager->sendFrame(msg.data(), msg.size(), WSOpCode::Binary); } @@ -235,11 +247,7 @@ public: void resetDocManager() { -#if MOBILEAPP - // I suspect 
this might be useful even for the non-mobile case, but - // not 100% sure, so rather do it mobile-only for now disconnect(); -#endif _docManager = nullptr; } diff --git a/kit/Kit.cpp b/kit/Kit.cpp index a302f6e35..54ad81647 100644 --- a/kit/Kit.cpp +++ b/kit/Kit.cpp @@ -781,7 +781,9 @@ public: " session for url: " << anonymizeUrl(_url) << " for sessionId: " << sessionId << " on jailId: " << _jailId); - auto session = std::make_shared<ChildSession>(sessionId, _jailId, *this); + auto session = std::make_shared<ChildSession>( + _websocketHandler, + sessionId, _jailId, *this); _sessions.emplace(sessionId, session); int viewId = session->getViewId(); @@ -2072,7 +2074,7 @@ std::shared_ptr<lok::Document> getLOKDocument() return Document::_loKitDocument; } -class KitWebSocketHandler final : public WebSocketHandler, public std::enable_shared_from_this<KitWebSocketHandler> +class KitWebSocketHandler final : public WebSocketHandler { std::shared_ptr<TileQueue> _queue; std::string _socketName; @@ -2137,7 +2139,9 @@ protected: Util::setThreadName("kitbroker_" + docId); if (!document) - document = std::make_shared<Document>(_loKit, _jailId, docKey, docId, url, _queue, shared_from_this()); + document = std::make_shared<Document>( + _loKit, _jailId, docKey, docId, url, _queue, + std::static_pointer_cast<WebSocketHandler>(shared_from_this())); // Validate and create session. if (!(url == document->getUrl() && document->createSession(sessionId))) @@ -2633,7 +2637,7 @@ void lokit_main( KitSocketPoll mainKit; mainKit.runOnClientThread(); // We will do the polling on this thread. 
- std::shared_ptr<SocketHandlerInterface> websocketHandler = + std::shared_ptr<ProtocolHandlerInterface> websocketHandler = std::make_shared<KitWebSocketHandler>("child_ws", loKit, jailId); #if !MOBILEAPP mainKit.insertNewUnixSocket(MasterLocation, pathAndQuery, websocketHandler); diff --git a/net/Socket.cpp b/net/Socket.cpp index cbe5a0a52..5bb1fa250 100644 --- a/net/Socket.cpp +++ b/net/Socket.cpp @@ -204,7 +204,7 @@ void SocketPoll::wakeupWorld() void SocketPoll::insertNewWebSocketSync( const Poco::URI &uri, - const std::shared_ptr<SocketHandlerInterface>& websocketHandler) + const std::shared_ptr<ProtocolHandlerInterface>& websocketHandler) { LOG_INF("Connecting to " << uri.getHost() << " : " << uri.getPort() << " : " << uri.getPath()); @@ -277,7 +277,7 @@ void SocketPoll::insertNewWebSocketSync( // should this be a static method in the WebsocketHandler(?) void SocketPoll::clientRequestWebsocketUpgrade(const std::shared_ptr<StreamSocket>& socket, - const std::shared_ptr<SocketHandlerInterface>& websocketHandler, + const std::shared_ptr<ProtocolHandlerInterface>& websocketHandler, const std::string &pathAndQuery) { // cf. WebSocketHandler::upgradeToWebSocket (?) 
@@ -304,7 +304,7 @@ void SocketPoll::clientRequestWebsocketUpgrade(const std::shared_ptr<StreamSocke void SocketPoll::insertNewUnixSocket( const std::string &location, const std::string &pathAndQuery, - const std::shared_ptr<SocketHandlerInterface>& websocketHandler) + const std::shared_ptr<ProtocolHandlerInterface>& websocketHandler) { int fd = socket(AF_UNIX, SOCK_STREAM | SOCK_NONBLOCK, 0); @@ -337,7 +337,7 @@ void SocketPoll::insertNewUnixSocket( void SocketPoll::insertNewFakeSocket( int peerSocket, - const std::shared_ptr<SocketHandlerInterface>& websocketHandler) + const std::shared_ptr<ProtocolHandlerInterface>& websocketHandler) { LOG_INF("Connecting to " << peerSocket); int fd = fakeSocketSocket(); diff --git a/net/Socket.hpp b/net/Socket.hpp index c95b93dd7..99fdf259a 100644 --- a/net/Socket.hpp +++ b/net/Socket.hpp @@ -344,12 +344,21 @@ private: }; class StreamSocket; +class MessageHandlerInterface; -/// Interface that handles the actual incoming message. -class SocketHandlerInterface +/// Interface that decodes the actual incoming message. +class ProtocolHandlerInterface : + public std::enable_shared_from_this<ProtocolHandlerInterface> { +protected: + /// We own a message handler, after decoding the socket data we pass it on as messages. + std::shared_ptr<MessageHandlerInterface> _msgHandler; public: - virtual ~SocketHandlerInterface() {} + // ------------------------------------------------------------------ + // Interface for implementing low level socket goodness from streams. + // ------------------------------------------------------------------ + virtual ~ProtocolHandlerInterface() { } + /// Called when the socket is newly created to /// set the socket associated with this ResponseClient. /// Will be called exactly once. @@ -374,10 +383,81 @@ public: /// Will be called exactly once. 
virtual void onDisconnect() {} + // ----------------------------------------------------------------- + // Interface for external MessageHandlers + // ----------------------------------------------------------------- +public: + void setMessageHandler(const std::shared_ptr<MessageHandlerInterface> &msgHandler) + { + _msgHandler = msgHandler; + } + + /// Clear all external references + virtual void dispose() { _msgHandler.reset(); } + + virtual int sendTextMessage(const std::string &msg, const size_t len, bool flush = false) const = 0; + virtual int sendBinaryMessage(const char *data, const size_t len, bool flush = false) const = 0; + virtual void shutdown(bool goingAway = false, const std::string &statusMessage = "") = 0; + + virtual void getIOStats(uint64_t &sent, uint64_t &recv) = 0; + /// Append pretty printed internal state to a line virtual void dumpState(std::ostream& os) { os << "\n"; } }; +/// A ProtocolHandlerInterface with dummy sending API. +class SimpleSocketHandler : public ProtocolHandlerInterface +{ +public: + SimpleSocketHandler() {} + int sendTextMessage(const std::string &, const size_t, bool) const override { return 0; } + int sendBinaryMessage(const char *, const size_t , bool ) const override { return 0; } + void shutdown(bool, const std::string &) override {} + void getIOStats(uint64_t &, uint64_t &) override {} +}; + +/// Interface that receives and sends incoming messages. 
+class MessageHandlerInterface : + public std::enable_shared_from_this<MessageHandlerInterface> +{ +protected: + std::shared_ptr<ProtocolHandlerInterface> _protocol; + MessageHandlerInterface(const std::shared_ptr<ProtocolHandlerInterface> &protocol) : + _protocol(protocol) + { + } + virtual ~MessageHandlerInterface() {} + +public: + /// Setup, after construction for shared_from_this + void initialize() + { + if (_protocol) + _protocol->setMessageHandler(shared_from_this()); + } + + /// Clear all external references + virtual void dispose() + { + if (_protocol) + { + _protocol->dispose(); + _protocol.reset(); + } + } + + /// Do we have something to send ? + virtual bool hasQueuedMessages() const = 0; + /// Please send them to me then. + virtual void writeQueuedMessages() = 0; + /// We just got a message - here it is + virtual void handleMessage(const std::vector<char> &data) = 0; + /// Get notified that the underlying transports disconnected + virtual void onDisconnect() = 0; + /// Append pretty printed internal state to a line + virtual void dumpState(std::ostream& os) = 0; +}; + /// Handles non-blocking socket event polling. /// Only polls on N-Sockets and invokes callback and /// doesn't manage buffers or client data. @@ -672,16 +752,16 @@ public: /// Inserts a new remote websocket to be polled. /// NOTE: The DNS lookup is synchronous. 
void insertNewWebSocketSync(const Poco::URI &uri, - const std::shared_ptr<SocketHandlerInterface>& websocketHandler); + const std::shared_ptr<ProtocolHandlerInterface>& websocketHandler); void insertNewUnixSocket( const std::string &location, const std::string &pathAndQuery, - const std::shared_ptr<SocketHandlerInterface>& websocketHandler); + const std::shared_ptr<ProtocolHandlerInterface>& websocketHandler); #else void insertNewFakeSocket( int peerSocket, - const std::shared_ptr<SocketHandlerInterface>& websocketHandler); + const std::shared_ptr<ProtocolHandlerInterface>& websocketHandler); #endif typedef std::function<void()> CallbackFn; @@ -736,7 +816,7 @@ protected: private: /// Generate the request to connect & upgrade this socket to a given path void clientRequestWebsocketUpgrade(const std::shared_ptr<StreamSocket>& socket, - const std::shared_ptr<SocketHandlerInterface>& websocketHandler, + const std::shared_ptr<ProtocolHandlerInterface>& websocketHandler, const std::string &pathAndQuery); /// Initialize the poll fds array with the right events @@ -791,12 +871,13 @@ private: }; /// A plain, non-blocking, data streaming socket. -class StreamSocket : public Socket, public std::enable_shared_from_this<StreamSocket> +class StreamSocket : public Socket, + public std::enable_shared_from_this<StreamSocket> { public: /// Create a StreamSocket from native FD. StreamSocket(const int fd, bool /* isClient */, - std::shared_ptr<SocketHandlerInterface> socketHandler) : + std::shared_ptr<ProtocolHandlerInterface> socketHandler) : Socket(fd), _socketHandler(std::move(socketHandler)), _bytesSent(0), @@ -933,7 +1014,7 @@ public: } /// Replace the existing SocketHandler with a new one. 
- void setHandler(std::shared_ptr<SocketHandlerInterface> handler) + void setHandler(std::shared_ptr<ProtocolHandlerInterface> handler) { _socketHandler = std::move(handler); _socketHandler->onConnect(shared_from_this()); @@ -944,9 +1025,9 @@ public: /// but we can't have a shared_ptr in the ctor. template <typename TSocket> static - std::shared_ptr<TSocket> create(const int fd, bool isClient, std::shared_ptr<SocketHandlerInterface> handler) + std::shared_ptr<TSocket> create(const int fd, bool isClient, std::shared_ptr<ProtocolHandlerInterface> handler) { - SocketHandlerInterface* pHandler = handler.get(); + ProtocolHandlerInterface* pHandler = handler.get(); auto socket = std::make_shared<TSocket>(fd, isClient, std::move(handler)); pHandler->onConnect(socket); return socket; @@ -1157,14 +1238,14 @@ protected: return _shutdownSignalled; } - const std::shared_ptr<SocketHandlerInterface>& getSocketHandler() const + const std::shared_ptr<ProtocolHandlerInterface>& getSocketHandler() const { return _socketHandler; } private: /// Client handling the actual data. 
- std::shared_ptr<SocketHandlerInterface> _socketHandler; + std::shared_ptr<ProtocolHandlerInterface> _socketHandler; std::vector<char> _inBuffer; std::vector<char> _outBuffer; diff --git a/net/SslSocket.hpp b/net/SslSocket.hpp index ba9954f56..27e075328 100644 --- a/net/SslSocket.hpp +++ b/net/SslSocket.hpp @@ -20,7 +20,7 @@ class SslStreamSocket final : public StreamSocket { public: SslStreamSocket(const int fd, bool isClient, - std::shared_ptr<SocketHandlerInterface> responseClient) : + std::shared_ptr<ProtocolHandlerInterface> responseClient) : StreamSocket(fd, isClient, std::move(responseClient)), _bio(nullptr), _ssl(nullptr), diff --git a/net/WebSocketHandler.hpp b/net/WebSocketHandler.hpp index 130f81b69..1c2977602 100644 --- a/net/WebSocketHandler.hpp +++ b/net/WebSocketHandler.hpp @@ -24,7 +24,7 @@ #include <Poco/Net/HTTPResponse.h> #include <Poco/Net/WebSocket.h> -class WebSocketHandler : public SocketHandlerInterface +class WebSocketHandler : public ProtocolHandlerInterface { private: /// The socket that owns us (we can't own it). @@ -94,7 +94,7 @@ public: upgradeToWebSocket(request); } - /// Implementation of the SocketHandlerInterface. + /// Implementation of the ProtocolHandlerInterface. void onConnect(const std::shared_ptr<StreamSocket>& socket) override { _socket = socket; @@ -146,6 +146,24 @@ public: #endif } + void shutdown(bool goingAway, const std::string &statusMessage) override + { + shutdown(goingAway ? 
WebSocketHandler::StatusCodes::ENDPOINT_GOING_AWAY : + WebSocketHandler::StatusCodes::NORMAL_CLOSE, statusMessage); + } + + void getIOStats(uint64_t &sent, uint64_t &recv) override + { + std::shared_ptr<StreamSocket> socket = getSocket().lock(); + if (socket) + socket->getIOStats(sent, recv); + else + { + sent = 0; + recv = 0; + } + } + void shutdown(const StatusCodes statusCode = StatusCodes::NORMAL_CLOSE, const std::string& statusMessage = "") { if (!_shuttingDown) @@ -384,7 +402,7 @@ public: return true; } - /// Implementation of the SocketHandlerInterface. + /// Implementation of the ProtocolHandlerInterface. virtual void handleIncomingMessage(SocketDisposition&) override { // LOG_TRC("***** WebSocketHandler::handleIncomingMessage()"); @@ -421,7 +439,10 @@ public: std::chrono::duration_cast<std::chrono::milliseconds>(now - _lastPingSentTime).count(); timeoutMaxMs = std::min(timeoutMaxMs, PingFrequencyMs - timeSincePingMs); } - return POLLIN; + int events = POLLIN; + if (_msgHandler && _msgHandler->hasQueuedMessages()) + events |= POLLOUT; + return events; } #if !MOBILEAPP @@ -483,13 +504,34 @@ private: #endif } public: - /// By default rely on the socket buffer. - void performWrites() override {} + void performWrites() override + { + if (_msgHandler) + _msgHandler->writeQueuedMessages(); + } + + void onDisconnect() override + { + if (_msgHandler) + _msgHandler->onDisconnect(); + } /// Sends a WebSocket Text message. int sendMessage(const std::string& msg) const { - return sendMessage(msg.data(), msg.size(), WSOpCode::Text); + return sendTextMessage(msg, msg.size()); + } + + /// Implementation of the ProtocolHandlerInterface. + int sendTextMessage(const std::string &msg, const size_t len, bool flush = false) const override + { + return sendMessage(msg.data(), len, WSOpCode::Text, flush); + } + + /// Implementation of the ProtocolHandlerInterface. 
+ int sendBinaryMessage(const char *data, const size_t len, bool flush = false) const override + { + return sendMessage(data, len, WSOpCode::Binary, flush); } /// Sends a WebSocket message of WPOpCode type. @@ -506,9 +548,7 @@ public: std::shared_ptr<StreamSocket> socket = _socket.lock(); return sendFrame(socket, data, len, WSFrameMask::Fin | static_cast<unsigned char>(code), flush); } - private: - /// Sends a WebSocket frame given the data, length, and flags. /// Returns the number of bytes written (including frame overhead) on success, /// 0 for closed/invalid socket, and -1 for other errors. @@ -615,8 +655,10 @@ protected: } /// To be overriden to handle the websocket messages the way you need. - virtual void handleMessage(const std::vector<char> &/*data*/) + virtual void handleMessage(const std::vector<char> &data) { + if (_msgHandler) + _msgHandler->handleMessage(data); } std::weak_ptr<StreamSocket>& getSocket() @@ -629,6 +671,7 @@ protected: _socket = socket; } + /// Implementation of the ProtocolHandlerInterface. void dumpState(std::ostream& os) override; private: diff --git a/test/UnitWOPIVersionRestore.cpp b/test/UnitWOPIVersionRestore.cpp index 3ad8dab09..16192c621 100644 --- a/test/UnitWOPIVersionRestore.cpp +++ b/test/UnitWOPIVersionRestore.cpp @@ -68,6 +68,7 @@ public: { constexpr char testName[] = "UnitWOPIVersionRestore"; + LOG_TRC("invokeTest " << (int)_phase); switch (_phase) { case Phase::Load: diff --git a/tools/WebSocketDump.cpp b/tools/WebSocketDump.cpp index e2fe32e54..c699a8fed 100644 --- a/tools/WebSocketDump.cpp +++ b/tools/WebSocketDump.cpp @@ -50,7 +50,7 @@ private: }; /// Handles incoming connections and dispatches to the appropriate handler. 
-class ClientRequestDispatcher : public SocketHandlerInterface +class ClientRequestDispatcher : public SimpleSocketHandler { public: ClientRequestDispatcher() diff --git a/wsd/ClientSession.cpp b/wsd/ClientSession.cpp index 696411fbf..29e420dad 100644 --- a/wsd/ClientSession.cpp +++ b/wsd/ClientSession.cpp @@ -38,12 +38,14 @@ using Poco::Path; static std::mutex GlobalSessionMapMutex; static std::unordered_map<std::string, std::weak_ptr<ClientSession>> GlobalSessionMap; -ClientSession::ClientSession(const std::string& id, - const std::shared_ptr<DocumentBroker>& docBroker, - const Poco::URI& uriPublic, - const bool readOnly, - const std::string& hostNoTrust) : - Session("ToClient-" + id, id, readOnly), +ClientSession::ClientSession( + const std::shared_ptr<ProtocolHandlerInterface>& ws, + const std::string& id, + const std::shared_ptr<DocumentBroker>& docBroker, + const Poco::URI& uriPublic, + const bool readOnly, + const std::string& hostNoTrust) : + Session(ws, "ToClient-" + id, id, readOnly), _docBroker(docBroker), _uriPublic(uriPublic), _isDocumentOwner(false), @@ -86,7 +88,8 @@ ClientSession::ClientSession(const std::string& id, void ClientSession::construct() { std::unique_lock<std::mutex> lock(GlobalSessionMapMutex); - GlobalSessionMap[getId()] = shared_from_this(); + MessageHandlerInterface::initialize(); + GlobalSessionMap[getId()] = client_from_this(); } ClientSession::~ClientSession() @@ -444,7 +447,7 @@ bool ClientSession::_handleInput(const char *buffer, int length) } else if (tokens.equals(0, "canceltiles")) { - docBroker->cancelTileRequests(shared_from_this()); + docBroker->cancelTileRequests(client_from_this()); return true; } else if (tokens.equals(0, "commandvalues")) @@ -678,7 +681,7 @@ bool ClientSession::_handleInput(const char *buffer, int length) else LOG_INF("Tileprocessed message with an unknown tile ID"); - docBroker->sendRequestedTiles(shared_from_this()); + docBroker->sendRequestedTiles(client_from_this()); return true; } else if 
(tokens.equals(0, "removesession")) { @@ -882,7 +885,7 @@ bool ClientSession::sendTile(const char * /*buffer*/, int /*length*/, const Stri { TileDesc tileDesc = TileDesc::parse(tokens); tileDesc.setNormalizedViewId(getCanonicalViewId()); - docBroker->handleTileRequest(tileDesc, shared_from_this()); + docBroker->handleTileRequest(tileDesc, client_from_this()); } catch (const std::exception& exc) { @@ -900,7 +903,7 @@ bool ClientSession::sendCombinedTiles(const char* /*buffer*/, int /*length*/, co { TileCombined tileCombined = TileCombined::parse(tokens); tileCombined.setNormalizedViewId(getCanonicalViewId()); - docBroker->handleTileCombinedRequest(tileCombined, shared_from_this()); + docBroker->handleTileCombinedRequest(tileCombined, client_from_this()); } catch (const std::exception& exc) { @@ -981,17 +984,13 @@ void ClientSession::setReadOnly() sendTextFrame("perm: readonly"); } -int ClientSession::getPollEvents(std::chrono::steady_clock::time_point /* now */, - int & /* timeoutMaxMs */) +bool ClientSession::hasQueuedMessages() const { - LOG_TRC(getName() << " ClientSession has " << _senderQueue.size() << " write message(s) queued."); - int events = POLLIN; - if (_senderQueue.size()) - events |= POLLOUT; - return events; + return _senderQueue.size() > 0; } -void ClientSession::performWrites() + /// Please send them to me then. 
+void ClientSession::writeQueuedMessages() { LOG_TRC(getName() << " ClientSession: performing writes."); @@ -1706,11 +1705,10 @@ void ClientSession::dumpState(std::ostream& os) << "\n\t\tclipboardKeys[1]: " << _clipboardKeys[1] << "\n\t\tclip sockets: " << _clipSockets.size(); - std::shared_ptr<StreamSocket> socket = getSocket().lock(); - if (socket) + if (_protocol) { uint64_t sent, recv; - socket->getIOStats(sent, recv); + _protocol->getIOStats(sent, recv); os << "\n\t\tsent/keystroke: " << (double)sent/_keyEvents << "bytes"; } @@ -1781,7 +1779,7 @@ void ClientSession::handleTileInvalidation(const std::string& message, { TileCombined tileCombined = TileCombined::create(invalidTiles); tileCombined.setNormalizedViewId(normalizedViewId); - docBroker->handleTileCombinedRequest(tileCombined, shared_from_this()); + docBroker->handleTileCombinedRequest(tileCombined, client_from_this()); } } diff --git a/wsd/ClientSession.hpp b/wsd/ClientSession.hpp index fe39b9e7d..b47285dd0 100644 --- a/wsd/ClientSession.hpp +++ b/wsd/ClientSession.hpp @@ -24,12 +24,12 @@ class DocumentBroker; - /// Represents a session to a LOOL client, in the WSD process. -class ClientSession final : public Session, public std::enable_shared_from_this<ClientSession> +class ClientSession final : public Session { public: - ClientSession(const std::string& id, + ClientSession(const std::shared_ptr<ProtocolHandlerInterface>& ws, + const std::string& id, const std::shared_ptr<DocumentBroker>& docBroker, const Poco::URI& uriPublic, const bool isReadOnly, @@ -174,14 +174,19 @@ public: void rotateClipboardKey(bool notifyClient); private: + std::shared_ptr<ClientSession> client_from_this() + { + return std::static_pointer_cast<ClientSession>(shared_from_this()); + } + /// SocketHandler: disconnection event. void onDisconnect() override; - /// Does SocketHandler: have data or timeouts to setup. 
- int getPollEvents(std::chrono::steady_clock::time_point /* now */, - int & /* timeoutMaxMs */) override; - /// SocketHandler: write to socket. - void performWrites() override; + /// Does SocketHandler: have messages to send ? + bool hasQueuedMessages() const override; + + /// SocketHandler: send those messages + void writeQueuedMessages() override; virtual bool _handleInput(const char* buffer, int length) override; diff --git a/wsd/DocumentBroker.cpp b/wsd/DocumentBroker.cpp index b9e7e983c..8b0c883c0 100644 --- a/wsd/DocumentBroker.cpp +++ b/wsd/DocumentBroker.cpp @@ -1468,6 +1468,7 @@ void DocumentBroker::finalRemoveSession(const std::string& id) // Remove. The caller must have a reference to the session // in question, lest we destroy from underneath them. + it->second->dispose(); _sessions.erase(it); const size_t count = _sessions.size(); @@ -1497,11 +1498,12 @@ void DocumentBroker::finalRemoveSession(const std::string& id) } } -std::shared_ptr<ClientSession> DocumentBroker::createNewClientSession(const WebSocketHandler* ws, - const std::string& id, - const Poco::URI& uriPublic, - const bool isReadOnly, - const std::string& hostNoTrust) +std::shared_ptr<ClientSession> DocumentBroker::createNewClientSession( + const std::shared_ptr<ProtocolHandlerInterface> &ws, + const std::string& id, + const Poco::URI& uriPublic, + const bool isReadOnly, + const std::string& hostNoTrust) { try { @@ -1510,13 +1512,13 @@ std::shared_ptr<ClientSession> DocumentBroker::createNewClientSession(const WebS { const std::string statusReady = "statusindicator: ready"; LOG_TRC("Sending to Client [" << statusReady << "]."); - ws->sendMessage(statusReady); + ws->sendTextMessage(statusReady, statusReady.size()); } // In case of WOPI, if this session is not set as readonly, it might be set so // later after making a call to WOPI host which tells us the permission on files // (UserCanWrite param). 
- auto session = std::make_shared<ClientSession>(id, shared_from_this(), uriPublic, isReadOnly, hostNoTrust); + auto session = std::make_shared<ClientSession>(ws, id, shared_from_this(), uriPublic, isReadOnly, hostNoTrust); session->construct(); return session; @@ -2252,7 +2254,9 @@ bool ConvertToBroker::startConversion(SocketDisposition &disposition, const std: // Create a session to load the document. const bool isReadOnly = true; - _clientSession = std::make_shared<ClientSession>(id, docBroker, getPublicUri(), isReadOnly, "nocliphost"); + // FIXME: associate this with moveSocket (?) + std::shared_ptr<ProtocolHandlerInterface> nullPtr; + _clientSession = std::make_shared<ClientSession>(nullPtr, id, docBroker, getPublicUri(), isReadOnly, "nocliphost"); _clientSession->construct(); if (!_clientSession) diff --git a/wsd/DocumentBroker.hpp b/wsd/DocumentBroker.hpp index f56bd1e3f..68369d274 100644 --- a/wsd/DocumentBroker.hpp +++ b/wsd/DocumentBroker.hpp @@ -244,11 +244,12 @@ public: void finalRemoveSession(const std::string& id); /// Create new client session - std::shared_ptr<ClientSession> createNewClientSession(const WebSocketHandler* ws, - const std::string& id, - const Poco::URI& uriPublic, - const bool isReadOnly, - const std::string& hostNoTrust); + std::shared_ptr<ClientSession> createNewClientSession( + const std::shared_ptr<ProtocolHandlerInterface> &ws, + const std::string& id, + const Poco::URI& uriPublic, + const bool isReadOnly, + const std::string& hostNoTrust); /// Thread safe termination of this broker if it has a lingering thread void joinThread(); diff --git a/wsd/LOOLWSD.cpp b/wsd/LOOLWSD.cpp index b009bcaa5..241867af7 100644 --- a/wsd/LOOLWSD.cpp +++ b/wsd/LOOLWSD.cpp @@ -236,7 +236,7 @@ namespace { #if ENABLE_SUPPORT_KEY -inline void shutdownLimitReached(WebSocketHandler& ws) +inline void shutdownLimitReached(const std::shared_ptr<WebSocketHandler>& ws) { const std::string error = Poco::format(PAYLOAD_UNAVAILABLE_LIMIT_REACHED, 
LOOLWSD::MaxDocuments, LOOLWSD::MaxConnections); LOG_INF("Sending client 'hardlimitreached' message: " << error); @@ -244,10 +244,10 @@ inline void shutdownLimitReached(WebSocketHandler& ws) try { // Let the client know we are shutting down. - ws.sendMessage(error); + ws->sendMessage(error); // Shutdown. - ws.shutdown(WebSocketHandler::StatusCodes::POLICY_VIOLATION); + ws->shutdown(WebSocketHandler::StatusCodes::POLICY_VIOLATION); } catch (const std::exception& ex) { @@ -1728,11 +1728,12 @@ std::mutex Connection::Mutex; /// Otherwise, creates and adds a new one to DocBrokers. /// May return null if terminating or MaxDocuments limit is reached. /// After returning a valid instance DocBrokers must be cleaned up after exceptions. -static std::shared_ptr<DocumentBroker> findOrCreateDocBroker(WebSocketHandler& ws, - const std::string& uri, - const std::string& docKey, - const std::string& id, - const Poco::URI& uriPublic) +static std::shared_ptr<DocumentBroker> + findOrCreateDocBroker(const std::shared_ptr<WebSocketHandler>& ws, + const std::string& uri, + const std::string& docKey, + const std::string& id, + const Poco::URI& uriPublic) { LOG_INF("Find or create DocBroker for docKey [" << docKey << "] for session [" << id << "] on url [" << LOOLWSD::anonymizeUrl(uriPublic.toString()) << "]."); @@ -1761,8 +1762,8 @@ static std::shared_ptr<DocumentBroker> findOrCreateDocBroker(WebSocketHandler& w if (docBroker->isMarkedToDestroy()) { LOG_WRN("DocBroker with docKey [" << docKey << "] that is marked to be destroyed. 
Rejecting client request."); - ws.sendMessage("error: cmd=load kind=docunloading"); - ws.shutdown(WebSocketHandler::StatusCodes::ENDPOINT_GOING_AWAY, "error: cmd=load kind=docunloading"); + ws->sendMessage("error: cmd=load kind=docunloading"); + ws->shutdown(WebSocketHandler::StatusCodes::ENDPOINT_GOING_AWAY, "error: cmd=load kind=docunloading"); return nullptr; } } @@ -1780,7 +1781,7 @@ static std::shared_ptr<DocumentBroker> findOrCreateDocBroker(WebSocketHandler& w // Indicate to the client that we're connecting to the docbroker. const std::string statusConnect = "statusindicator: connect"; LOG_TRC("Sending to Client [" << statusConnect << "]."); - ws.sendMessage(statusConnect); + ws->sendMessage(statusConnect); if (!docBroker) { @@ -1932,6 +1933,11 @@ private: addNewChild(child); }); } + catch (const std::bad_weak_ptr&) + { + // Using shared_from_this() from a constructor is not good. + assert(false); + } catch (const std::exception& exc) { // Probably don't have enough data just yet. @@ -1995,7 +2001,7 @@ public: #endif /// Handles incoming connections and dispatches to the appropriate handler. -class ClientRequestDispatcher : public SocketHandlerInterface +class ClientRequestDispatcher : public SimpleSocketHandler { public: ClientRequestDispatcher() @@ -2780,7 +2786,7 @@ private: LOG_TRC("Client WS request: " << request.getURI() << ", url: " << url << ", socket #" << socket->getFD()); // First Upgrade. - WebSocketHandler ws(_socket, request); + auto ws = std::make_shared<WebSocketHandler>(_socket, request); // Response to clients beyond this point is done via WebSocket. try @@ -2807,7 +2813,7 @@ private: // Indicate to the client that document broker is searching. 
const std::string status("statusindicator: find"); LOG_TRC("Sending to Client [" << status << "]."); - ws.sendMessage(status); + ws->sendMessage(status); LOG_INF("Sanitized URI [" << LOOLWSD::anonymizeUrl(url) << "] to [" << LOOLWSD::anonymizeUrl(uriPublic.toString()) << "] and mapped to docKey [" << docKey << "] for session [" << _id << "]."); @@ -2837,11 +2843,11 @@ private: #endif std::shared_ptr<ClientSession> clientSession = - docBroker->createNewClientSession(&ws, _id, uriPublic, isReadOnly, hostNoTrust); + docBroker->createNewClientSession(ws, _id, uriPublic, isReadOnly, hostNoTrust); if (clientSession) { // Transfer the client socket to the DocumentBroker when we get back to the poll: - disposition.setMove([docBroker, clientSession] + disposition.setMove([docBroker, clientSession, ws] (const std::shared_ptr<Socket> &moveSocket) { // Make sure the thread is running before adding callback. @@ -2850,16 +2856,16 @@ private: // We no longer own this socket. moveSocket->setThreadOwner(std::thread::id()); - docBroker->addCallback([docBroker, moveSocket, clientSession]() + docBroker->addCallback([docBroker, moveSocket, clientSession, ws]() { try { auto streamSocket = std::static_pointer_cast<StreamSocket>(moveSocket); - // Set the ClientSession to handle Socket events. - streamSocket->setHandler(clientSession); - LOG_DBG("Socket #" << moveSocket->getFD() << " handler is " << clientSession->getName()); + // Set WebSocketHandler's socket after its construction for shared_ptr goodness. + streamSocket->setHandler(ws); + LOG_DBG("Socket #" << moveSocket->getFD() << " handler is " << clientSession->getName()); // Move the socket into DocBroker. 
docBroker->addSocketToPoll(moveSocket); @@ -2868,7 +2874,8 @@ private: checkDiskSpaceAndWarnClients(true); #if !ENABLE_SUPPORT_KEY - // Users of development versions get just an info when reaching max documents or connections + // Users of development versions get just an info + // when reaching max documents or connections checkSessionLimitsAndWarnClients(); #endif } @@ -2909,8 +2916,8 @@ private: { LOG_ERR("Error while handling Client WS Request: " << exc.what()); const std::string msg = "error: cmd=internal kind=load"; - ws.sendMessage(msg); - ws.shutdown(WebSocketHandler::StatusCodes::ENDPOINT_GOING_AWAY, msg); + ws->sendMessage(msg); + ws->shutdown(WebSocketHandler::StatusCodes::ENDPOINT_GOING_AWAY, msg); } } diff --git a/wsd/TestStubs.cpp b/wsd/TestStubs.cpp index b75499ec2..ca04416da 100644 --- a/wsd/TestStubs.cpp +++ b/wsd/TestStubs.cpp @@ -25,16 +25,16 @@ void ClientSession::enqueueSendMessage(const std::shared_ptr<Message>& /*data*/) ClientSession::~ClientSession() {} -void ClientSession::performWrites() {} - void ClientSession::onDisconnect() {} +bool ClientSession::hasQueuedMessages() const { return false; } + +void ClientSession::writeQueuedMessages() {} + void ClientSession::dumpState(std::ostream& /*os*/) {} void ClientSession::setReadOnly() {} bool ClientSession::_handleInput(const char* /*buffer*/, int /*length*/) { return false; } -int ClientSession::getPollEvents(std::chrono::steady_clock::time_point /* now */, int & /* timeoutMaxMs */) { return 0; } - /* vim:set shiftwidth=4 softtabstop=4 expandtab: */ _______________________________________________ Libreoffice-commits mailing list libreoffice-comm...@lists.freedesktop.org https://lists.freedesktop.org/mailman/listinfo/libreoffice-commits