common/Unit.hpp          |   1
net/DelaySocket.cpp      |   7 +-
net/ServerSocket.hpp     |   7 +-
net/Socket.cpp           |  10 ++++
net/Socket.hpp           | 116 +++++++++++++++++++++++++---------------------
net/WebSocketHandler.hpp |   4 -
net/loolnb.cpp           |   8 +--
test/UnitFuzz.cpp        |   1
wsd/ClientSession.cpp    |   8 +--
wsd/ClientSession.hpp    |   2
wsd/LOOLWSD.cpp          |  61 ++++++++++++------------
11 files changed, 121 insertions(+), 104 deletions(-)
New commits: commit 9e45fb30d7f33b57fda9f615447ae8ac9b920fc1 Author: Michael Meeks <michael.me...@collabora.com> Date: Fri May 5 11:51:43 2017 +0100 SocketDisposition: push it down the stack, and cleanup around that. Dung out overlapping return enumerations. Move more work into 'move' callbacks at a safer time, etc. Change-Id: I62ba5a35f12073b7b9c8de4674be9dae519a8aca diff --git a/common/Unit.hpp b/common/Unit.hpp index e8197fd1..5f8d20ea 100644 --- a/common/Unit.hpp +++ b/common/Unit.hpp @@ -177,6 +177,7 @@ public: /// Intercept incoming requests, so unit tests can silently communicate virtual bool filterHandleRequest( TestRequest /* type */, + SocketDisposition & /* disposition */, WebSocketHandler & /* handler */) { return false; diff --git a/net/DelaySocket.cpp b/net/DelaySocket.cpp index 723357c1..20990e5c 100644 --- a/net/DelaySocket.cpp +++ b/net/DelaySocket.cpp @@ -122,7 +122,8 @@ public: _state = newState; } - HandleResult handlePoll(std::chrono::steady_clock::time_point now, int events) override + void handlePoll(SocketDisposition &disposition, + std::chrono::steady_clock::time_point now, int events) override { if (_state == ReadWrite && (events & POLLIN)) { @@ -215,9 +216,7 @@ public: } if (_state == Closed) - return HandleResult::SOCKET_CLOSED; - else - return HandleResult::CONTINUE; + disposition.setClosed(); } }; diff --git a/net/ServerSocket.hpp b/net/ServerSocket.hpp index 805430ea..4d4bb353 100644 --- a/net/ServerSocket.hpp +++ b/net/ServerSocket.hpp @@ -88,8 +88,9 @@ public: void dumpState(std::ostream& os) override; - HandleResult handlePoll(std::chrono::steady_clock::time_point /* now */, - int events) override + void handlePoll(SocketDisposition &, + std::chrono::steady_clock::time_point /* now */, + int events) override { if (events & POLLIN) { @@ -103,8 +104,6 @@ public: LOG_DBG("Accepted client #" << clientSocket->getFD()); _clientPoller.insertNewSocket(clientSocket); } - - return Socket::HandleResult::CONTINUE; } private: diff --git 
a/net/Socket.cpp b/net/Socket.cpp index ad912880..5faeced7 100644 --- a/net/Socket.cpp +++ b/net/Socket.cpp @@ -122,6 +122,16 @@ void ServerSocket::dumpState(std::ostream& os) os << "\t" << getFD() << "\t<accept>\n"; } + +void SocketDisposition::execute() +{ + // We should have hard ownership of this socket. + assert(_socket->getThreadOwner() == std::this_thread::get_id()); + if (_socketMove) + _socketMove(_socket); + _socketMove = nullptr; +} + namespace { void dump_hex (const char *legend, const char *prefix, std::vector<char> buffer) diff --git a/net/Socket.hpp b/net/Socket.hpp index 694e82a5..0468eb9c 100644 --- a/net/Socket.hpp +++ b/net/Socket.hpp @@ -44,6 +44,48 @@ namespace Poco } } +class Socket; + +/// Helper to allow us to easily defer the movement of a socket +/// between polls to clarify thread ownership. +class SocketDisposition +{ + typedef std::function<void(const std::shared_ptr<Socket> &)> MoveFunction; + enum class Type { CONTINUE, CLOSED, MOVE }; + + Type _disposition; + MoveFunction _socketMove; + std::shared_ptr<Socket> _socket; + +public: + SocketDisposition(const std::shared_ptr<Socket> &socket) : + _disposition(Type::CONTINUE), + _socket(socket) + {} + ~SocketDisposition() + { + assert (!_socketMove); + } + void setMove() + { + _disposition = Type::MOVE; + } + void setMove(MoveFunction moveFn) + { + _socketMove = moveFn; + _disposition = Type::MOVE; + } + void setClosed() + { + _disposition = Type::CLOSED; + } + bool isMove() { return _disposition == Type::MOVE; } + bool isClosed() { return _disposition == Type::CLOSED; } + + /// Perform the queued up work. + void execute(); +}; + /// A non-blocking, streaming socket. 
class Socket { @@ -86,8 +128,9 @@ public: int &timeoutMaxMs) = 0; /// Handle results of events returned from poll - enum class HandleResult { CONTINUE, SOCKET_CLOSED, MOVED }; - virtual HandleResult handlePoll(std::chrono::steady_clock::time_point now, int events) = 0; + virtual void handlePoll(SocketDisposition &disposition, + std::chrono::steady_clock::time_point now, + int events) = 0; /// manage latency issues around packet aggregation virtual void setNoDelay() @@ -411,26 +454,30 @@ public: // Fire the poll callbacks and remove dead fds. std::chrono::steady_clock::time_point newNow = std::chrono::steady_clock::now(); + for (int i = static_cast<int>(size) - 1; i >= 0; --i) { - Socket::HandleResult res = Socket::HandleResult::SOCKET_CLOSED; + SocketDisposition disposition(_pollSockets[i]); try { - res = _pollSockets[i]->handlePoll(newNow, _pollFds[i].revents); + _pollSockets[i]->handlePoll(disposition, newNow, + _pollFds[i].revents); } catch (const std::exception& exc) { LOG_ERR("Error while handling poll for socket #" << _pollFds[i].fd << " in " << _name << ": " << exc.what()); + disposition.setClosed(); } - if (res == Socket::HandleResult::SOCKET_CLOSED || - res == Socket::HandleResult::MOVED) + if (disposition.isMove() || disposition.isClosed()) { LOG_DBG("Removing socket #" << _pollFds[i].fd << " (of " << _pollSockets.size() << ") from " << _name); _pollSockets.erase(_pollSockets.begin() + i); } + + disposition.execute(); } } @@ -608,14 +655,8 @@ public: /// Will be called exactly once. virtual void onConnect(const std::shared_ptr<StreamSocket>& socket) = 0; - enum class SocketOwnership - { - UNCHANGED, //< Same socket poll, business as usual. - MOVED //< The socket poll is now different. - }; - /// Called after successful socket reads. 
- virtual SocketHandlerInterface::SocketOwnership handleIncomingMessage() = 0; + virtual void handleIncomingMessage(SocketDisposition &disposition) = 0; /// Prepare our poll record; adjust @timeoutMaxMs downwards /// for timeouts, based on current time @now. @@ -773,15 +814,16 @@ protected: /// Called when a polling event is received. /// @events is the mask of events that triggered the wake. - HandleResult handlePoll(std::chrono::steady_clock::time_point now, - const int events) override + void handlePoll(SocketDisposition &disposition, + std::chrono::steady_clock::time_point now, + const int events) override { assertCorrectThread(); _socketHandler->checkTimeout(now); if (!events) - return Socket::HandleResult::CONTINUE; + return; // FIXME: need to close input, but not output (?) bool closed = (events & (POLLHUP | POLLERR | POLLNVAL)); @@ -801,8 +843,9 @@ protected: while (!_inBuffer.empty() && oldSize != _inBuffer.size()) { oldSize = _inBuffer.size(); - if (_socketHandler->handleIncomingMessage() == SocketHandlerInterface::SocketOwnership::MOVED) - return Socket::HandleResult::MOVED; + _socketHandler->handleIncomingMessage(disposition); + if (disposition.isMove()) + return; } do @@ -837,8 +880,8 @@ protected: _socketHandler->onDisconnect(); } - return _closed ? HandleResult::SOCKET_CLOSED : - HandleResult::CONTINUE; + if (_closed) + disposition.setClosed(); } /// Override to write data out to socket. @@ -917,39 +960,6 @@ protected: friend class SimpleResponseClient; }; -/// Helper to allow us to easily defer the movement of a socket -/// between polls to clarify thread ownership. 
-class SocketDisposition -{ - std::shared_ptr<StreamSocket> _socket; - typedef std::function<void(const std::shared_ptr<StreamSocket> &)> MoveFunction; - MoveFunction _socketMove; - SocketHandlerInterface::SocketOwnership _socketOwnership; -public: - SocketDisposition(const std::shared_ptr<StreamSocket> &socket) : - _socket(socket), - _socketOwnership(SocketHandlerInterface::SocketOwnership::UNCHANGED) - {} - ~SocketDisposition() - { - assert (!_socketMove); - } - void setMove(MoveFunction moveFn) - { - _socketMove = moveFn; - _socketOwnership = SocketHandlerInterface::SocketOwnership::MOVED; - } - SocketHandlerInterface::SocketOwnership execute() - { - // We should have hard ownership of this socket. - assert(_socket->getThreadOwner() == std::this_thread::get_id()); - if (_socketMove) - _socketMove(_socket); - _socketMove = nullptr; - return _socketOwnership; - } -}; - namespace HttpHelper { /// Sends file as HTTP response. diff --git a/net/WebSocketHandler.hpp b/net/WebSocketHandler.hpp index 14d97f81..4ff01c36 100644 --- a/net/WebSocketHandler.hpp +++ b/net/WebSocketHandler.hpp @@ -250,7 +250,7 @@ public: } /// Implementation of the SocketHandlerInterface. - virtual SocketHandlerInterface::SocketOwnership handleIncomingMessage() override + virtual void handleIncomingMessage(SocketDisposition&) override { auto socket = _socket.lock(); if (socket == nullptr) @@ -262,8 +262,6 @@ public: while (handleOneIncomingMessage(socket)) ; // can have multiple msgs in one recv'd packet. 
} - - return SocketHandlerInterface::SocketOwnership::UNCHANGED; } int getPollEvents(std::chrono::steady_clock::time_point now, diff --git a/net/loolnb.cpp b/net/loolnb.cpp index a014173a..e268b067 100644 --- a/net/loolnb.cpp +++ b/net/loolnb.cpp @@ -45,7 +45,7 @@ public: { } - virtual SocketHandlerInterface::SocketOwnership handleIncomingMessage() override + virtual void handleIncomingMessage(SocketDisposition &disposition) override { LOG_TRC("incoming WebSocket message"); if (_wsState == WSState::HTTP) @@ -89,16 +89,16 @@ public: std::string str = oss.str(); socket->_outBuffer.insert(socket->_outBuffer.end(), str.begin(), str.end()); - return SocketHandlerInterface::SocketOwnership::UNCHANGED; + return; } else if (tokens.count() == 2 && tokens[1] == "ws") { upgradeToWebSocket(req); - return SocketHandlerInterface::SocketOwnership::UNCHANGED; + return; } } - return WebSocketHandler::handleIncomingMessage(); + WebSocketHandler::handleIncomingMessage(disposition); } virtual void handleMessage(const bool fin, const WSOpCode code, std::vector<char> &data) override diff --git a/test/UnitFuzz.cpp b/test/UnitFuzz.cpp index 49575b5d..68367884 100644 --- a/test/UnitFuzz.cpp +++ b/test/UnitFuzz.cpp @@ -121,6 +121,7 @@ public: virtual bool filterHandleRequest( TestRequest /* type */, + SocketDisposition & /* disposition */, WebSocketHandler & /* socket */) override { #if 0 // loolnb diff --git a/wsd/ClientSession.cpp b/wsd/ClientSession.cpp index 0bec1538..55b17d64 100644 --- a/wsd/ClientSession.cpp +++ b/wsd/ClientSession.cpp @@ -51,13 +51,13 @@ ClientSession::~ClientSession() LOG_INF("~ClientSession dtor [" << getName() << "], current number of connections: " << curConnections); } -SocketHandlerInterface::SocketOwnership ClientSession::handleIncomingMessage() +void ClientSession::handleIncomingMessage(SocketDisposition &disposition) { if (UnitWSD::get().filterHandleRequest( - UnitWSD::TestRequest::Client, *this)) - return SocketHandlerInterface::SocketOwnership::UNCHANGED; 
+ UnitWSD::TestRequest::Client, disposition, *this)) + return; - return Session::handleIncomingMessage(); + Session::handleIncomingMessage(disposition); } bool ClientSession::_handleInput(const char *buffer, int length) diff --git a/wsd/ClientSession.hpp b/wsd/ClientSession.hpp index b0eefecf..22fad016 100644 --- a/wsd/ClientSession.hpp +++ b/wsd/ClientSession.hpp @@ -30,7 +30,7 @@ public: virtual ~ClientSession(); - SocketHandlerInterface::SocketOwnership handleIncomingMessage() override; + void handleIncomingMessage(SocketDisposition &) override; void setReadOnly() override; diff --git a/wsd/LOOLWSD.cpp b/wsd/LOOLWSD.cpp index 82126289..f06f67ef 100644 --- a/wsd/LOOLWSD.cpp +++ b/wsd/LOOLWSD.cpp @@ -1380,16 +1380,17 @@ private: } /// Called after successful socket reads. - SocketHandlerInterface::SocketOwnership handleIncomingMessage() override + void handleIncomingMessage(SocketDisposition &disposition) override { if (UnitWSD::get().filterHandleRequest( - UnitWSD::TestRequest::Prisoner, *this)) - return SocketHandlerInterface::SocketOwnership::UNCHANGED; + UnitWSD::TestRequest::Prisoner, disposition, *this)) + return; if (_childProcess.lock()) { // FIXME: inelegant etc. - derogate to websocket code - return WebSocketHandler::handleIncomingMessage(); + WebSocketHandler::handleIncomingMessage(disposition); + return; } auto socket = _socket.lock(); @@ -1402,7 +1403,7 @@ private: if (itBody == in.end()) { LOG_TRC("#" << socket->getFD() << " doesn't have enough data yet."); - return SocketHandlerInterface::SocketOwnership::UNCHANGED; + return; } // Skip the marker. @@ -1434,7 +1435,7 @@ private: if (request.getURI().find(NEW_CHILD_URI) != 0) { LOG_ERR("Invalid incoming URI."); - return SocketHandlerInterface::SocketOwnership::UNCHANGED; + return; } // New Child is spawned. 
@@ -1455,7 +1456,7 @@ private: if (pid <= 0) { LOG_ERR("Invalid PID in child URI [" << request.getURI() << "]."); - return SocketHandlerInterface::SocketOwnership::UNCHANGED; + return; } in.clear(); @@ -1466,24 +1467,21 @@ private: auto child = std::make_shared<ChildProcess>(pid, socket, request); - // Drop pretentions of ownership before adding to the list. - socket->setThreadOwner(std::thread::id(0)); - _childProcess = child; // weak - addNewChild(child); // Remove from prisoner poll since there is no activity // until we attach the childProcess (with this socket) // to a docBroker, which will do the polling. - return SocketHandlerInterface::SocketOwnership::MOVED; + disposition.setMove([child](const std::shared_ptr<Socket> &){ + // Drop pretentions of ownership before adding to the list. + addNewChild(child); + }); } catch (const std::exception& exc) { // Probably don't have enough data just yet. // TODO: timeout if we never get enough. } - - return SocketHandlerInterface::SocketOwnership::UNCHANGED; } /// Prisoner websocket fun ... (for now) @@ -1538,7 +1536,7 @@ private: } /// Called after successful socket reads. - SocketHandlerInterface::SocketOwnership handleIncomingMessage() override + void handleIncomingMessage(SocketDisposition &disposition) override { auto socket = _socket.lock(); std::vector<char>& in = socket->_inBuffer; @@ -1551,7 +1549,7 @@ private: if (itBody == in.end()) { LOG_DBG("#" << socket->getFD() << " doesn't have enough data yet."); - return SocketHandlerInterface::SocketOwnership::UNCHANGED; + return; } // Skip the marker. @@ -1586,17 +1584,16 @@ private: if (contentLength != Poco::Net::HTTPMessage::UNKNOWN_CONTENT_LENGTH && available < contentLength) { LOG_DBG("Not enough content yet: ContentLength: " << contentLength << ", available: " << available); - return SocketHandlerInterface::SocketOwnership::UNCHANGED; + return; } } catch (const std::exception& exc) { // Probably don't have enough data just yet. 
// TODO: timeout if we never get enough. - return SocketHandlerInterface::SocketOwnership::UNCHANGED; + return; } - SocketDisposition tailDisposition(socket); try { // Routing @@ -1615,7 +1612,7 @@ private: LOG_INF("Admin request: " << request.getURI()); if (AdminSocketHandler::handleInitialRequest(_socket, request)) { - tailDisposition.setMove([](const std::shared_ptr<StreamSocket> &moveSocket){ + disposition.setMove([](const std::shared_ptr<Socket> &moveSocket){ // Hand the socket over to the Admin poll. Admin::instance().insertNewSocket(moveSocket); }); @@ -1644,12 +1641,12 @@ private: reqPathTokens.count() > 0 && reqPathTokens[0] == "lool") { // All post requests have url prefix 'lool'. - handlePostRequest(request, message, tailDisposition); + handlePostRequest(request, message, disposition); } else if (reqPathTokens.count() > 2 && reqPathTokens[0] == "lool" && reqPathTokens[2] == "ws" && request.find("Upgrade") != request.end() && Poco::icompare(request["Upgrade"], "websocket") == 0) { - handleClientWsUpgrade(request, reqPathTokens[1], tailDisposition); + handleClientWsUpgrade(request, reqPathTokens[1], disposition); } else { @@ -1678,7 +1675,6 @@ private: // if we succeeded - remove the request from our input buffer // we expect one request per socket in.erase(in.begin(), itBody); - return tailDisposition.execute(); } int getPollEvents(std::chrono::steady_clock::time_point /* now */, @@ -1819,7 +1815,7 @@ private: } void handlePostRequest(const Poco::Net::HTTPRequest& request, Poco::MemoryInputStream& message, - SocketDisposition &tailDisposition) + SocketDisposition &disposition) { LOG_INF("Post request: [" << request.getURI() << "]"); @@ -1863,8 +1859,8 @@ private: auto clientSession = createNewClientSession(nullptr, _id, uriPublic, docBroker, isReadOnly); if (clientSession) { - tailDisposition.setMove([docBroker, clientSession, format] - (const std::shared_ptr<StreamSocket> &moveSocket) + disposition.setMove([docBroker, clientSession, format] + (const 
std::shared_ptr<Socket> &moveSocket) { // Perform all of this after removing the socket // Make sure the thread is running before adding callback. @@ -1875,7 +1871,8 @@ private: docBroker->addCallback([docBroker, moveSocket, clientSession, format]() { - clientSession->setSaveAsSocket(moveSocket); + auto streamSocket = std::static_pointer_cast<StreamSocket>(moveSocket); + clientSession->setSaveAsSocket(streamSocket); // Move the socket into DocBroker. docBroker->addSocketToPoll(moveSocket); @@ -2028,7 +2025,7 @@ private: } void handleClientWsUpgrade(const Poco::Net::HTTPRequest& request, const std::string& url, - SocketDisposition &tailDisposition) + SocketDisposition &disposition) { auto socket = _socket.lock(); if (!socket) @@ -2082,8 +2079,8 @@ private: if (clientSession) { // Transfer the client socket to the DocumentBroker when we get back to the poll: - tailDisposition.setMove([docBroker, clientSession] - (const std::shared_ptr<StreamSocket> &moveSocket) + disposition.setMove([docBroker, clientSession] + (const std::shared_ptr<Socket> &moveSocket) { // Make sure the thread is running before adding callback. docBroker->startThread(); @@ -2093,8 +2090,10 @@ private: docBroker->addCallback([docBroker, moveSocket, clientSession]() { + auto streamSocket = std::static_pointer_cast<StreamSocket>(moveSocket); + // Set the ClientSession to handle Socket events. - moveSocket->setHandler(clientSession); + streamSocket->setHandler(clientSession); LOG_DBG("Socket #" << moveSocket->getFD() << " handler is " << clientSession->getName()); // Move the socket into DocBroker. _______________________________________________ Libreoffice-commits mailing list libreoffice-comm...@lists.freedesktop.org https://lists.freedesktop.org/mailman/listinfo/libreoffice-commits