From: Peter Krystad <peter.krys...@linux.intel.com> Use the MPTCP socket lock to make mptcp_shutdown() and mptcp_close() mutually exclusive.
Since mptcp_close() is the only code path that removes entries from conn_list, we can safely traverse the list while interrupting the RCU critical section. Signed-off-by: Peter Krystad <peter.krys...@linux.intel.com> Signed-off-by: Paolo Abeni <pab...@redhat.com> --- net/mptcp/protocol.c | 222 ++++++++++++++++++++++++++++++++----------- net/mptcp/protocol.h | 9 +- 2 files changed, 172 insertions(+), 59 deletions(-) diff --git a/net/mptcp/protocol.c b/net/mptcp/protocol.c index 2e76b7450ce2..c00e837a1766 100644 --- a/net/mptcp/protocol.c +++ b/net/mptcp/protocol.c @@ -29,34 +29,48 @@ static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) struct mptcp_sock *msk = mptcp_sk(sk); int mss_now, size_goal, poffset, ret; struct mptcp_ext *mpext = NULL; + struct subflow_context *subflow; struct page *page = NULL; + struct hlist_node *node; struct sk_buff *skb; struct sock *ssk; size_t psize; pr_debug("msk=%p", msk); - if (!msk->connection_list && msk->subflow) { + if (msk->subflow) { pr_debug("fallback passthrough"); return sock_sendmsg(msk->subflow, msg); } + rcu_read_lock(); + node = rcu_dereference(hlist_first_rcu(&msk->conn_list)); + subflow = hlist_entry(node, struct subflow_context, node); + ssk = mptcp_subflow_tcp_socket(subflow)->sk; + sock_hold(ssk); + rcu_read_unlock(); + if (!msg_data_left(msg)) { pr_debug("empty send"); - return sock_sendmsg(msk->connection_list, msg); + ret = sock_sendmsg(mptcp_subflow_tcp_socket(subflow), msg); + goto put_out; } - ssk = msk->connection_list->sk; + pr_debug("conn_list->subflow=%p", subflow); - if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL)) - return -ENOTSUPP; + if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL)) { + ret = -ENOTSUPP; + goto put_out; + } /* Initial experiment: new page per send. 
Real code will * maintain list of active pages and DSS mappings, append to the * end and honor zerocopy */ page = alloc_page(GFP_KERNEL); - if (!page) - return -ENOMEM; + if (!page) { + ret = -ENOMEM; + goto put_out; + } /* Copy to page */ poffset = 0; @@ -68,8 +82,8 @@ static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) pr_debug("left=%zu", msg_data_left(msg)); if (!psize) { - put_page(page); - return -EINVAL; + ret = -EINVAL; + goto put_out; } lock_sock(sk); @@ -87,9 +101,8 @@ static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) ret = do_tcp_sendpages(ssk, page, poffset, min_t(int, size_goal, psize), msg->msg_flags | MSG_SENDPAGE_NOTLAST); - put_page(page); if (ret <= 0) - goto error_out; + goto release_out; if (skb == tcp_write_queue_tail(ssk)) pr_err("no new skb %p/%p", sk, ssk); @@ -117,10 +130,15 @@ static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) tcp_push(ssk, msg->msg_flags, mss_now, tcp_sk(ssk)->nonagle, size_goal); -error_out: +release_out: release_sock(ssk); release_sock(sk); +put_out: + if (page) + put_page(page); + + sock_put(ssk); return ret; } @@ -275,20 +293,26 @@ static int mptcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, struct mptcp_sock *msk = mptcp_sk(sk); struct subflow_context *subflow; struct mptcp_read_arg arg; + struct hlist_node *node; read_descriptor_t desc; struct tcp_sock *tp; struct sock *ssk; int copied = 0; long timeo; - if (!msk->connection_list) { + if (msk->subflow) { pr_debug("fallback-read subflow=%p", subflow_ctx(msk->subflow->sk)); return sock_recvmsg(msk->subflow, msg, flags); } - ssk = msk->connection_list->sk; - subflow = subflow_ctx(ssk); + rcu_read_lock(); + node = rcu_dereference(hlist_first_rcu(&msk->conn_list)); + subflow = hlist_entry(node, struct subflow_context, node); + ssk = mptcp_subflow_tcp_socket(subflow)->sk; + sock_hold(ssk); + rcu_read_unlock(); + tp = tcp_sk(ssk); lock_sock(sk); @@ -450,6 +474,8 @@ static int 
mptcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, release_sock(ssk); release_sock(sk); + sock_put(ssk); + return copied; } @@ -459,24 +485,56 @@ static int mptcp_init_sock(struct sock *sk) pr_debug("msk=%p", msk); + INIT_LIST_HEAD_RCU(&msk->conn_list); + spin_lock_init(&msk->conn_list_lock); + return 0; } +static void mptcp_flush_conn_list(struct sock *sk, struct list_head *list) +{ + struct mptcp_sock *msk = mptcp_sk(sk); + + INIT_LIST_HEAD_RCU(list); + spin_lock_bh(&msk->conn_list_lock); + list_splice_init(&msk->conn_list, list); + spin_unlock_bh(&msk->conn_list_lock); + + if (!list_empty(list)) + synchronize_rcu(); +} + static void mptcp_close(struct sock *sk, long timeout) { struct mptcp_sock *msk = mptcp_sk(sk); + struct subflow_context *subflow, *tmp; + struct socket *ssk = NULL; + struct list_head list; inet_sk_state_store(sk, TCP_CLOSE); + spin_lock_bh(&msk->conn_list_lock); if (msk->subflow) { - pr_debug("subflow=%p", subflow_ctx(msk->subflow->sk)); - sock_release(msk->subflow); + ssk = msk->subflow; + msk->subflow = NULL; } + spin_unlock_bh(&msk->conn_list_lock); + if (ssk) { + pr_debug("subflow=%p", ssk->sk); + sock_release(ssk); + } + + /* this is the only place where we can remove any entry from the + * conn_list. Additionally acquiring the socket lock here + * allows for mutual exclusion with mptcp_shutdown(). 
+ */ + lock_sock(sk); + mptcp_flush_conn_list(sk, &list); + release_sock(sk); - if (msk->connection_list) { - pr_debug("conn_list->subflow=%p", - subflow_ctx(msk->connection_list->sk)); - sock_release(msk->connection_list); + list_for_each_entry_safe(subflow, tmp, &list, node) { + pr_debug("conn_list->subflow=%p", subflow); + sock_release(mptcp_subflow_tcp_socket(subflow)); } sock_orphan(sk); @@ -518,7 +576,10 @@ static struct sock *mptcp_accept(struct sock *sk, int flags, int *err, msk->local_key = subflow->local_key; msk->token = subflow->token; token_update_accept(new_sock->sk, new_mptcp_sock->sk); - msk->connection_list = new_sock; + spin_lock_bh(&msk->conn_list_lock); + list_add_rcu(&subflow->node, &msk->conn_list); + msk->subflow = NULL; + spin_unlock_bh(&msk->conn_list_lock); crypto_key_sha1(msk->remote_key, NULL, &ack_seq); msk->write_seq = subflow->idsn + 1; @@ -550,46 +611,46 @@ static int mptcp_setsockopt(struct sock *sk, int level, int optname, char __user *uoptval, unsigned int optlen) { struct mptcp_sock *msk = mptcp_sk(sk); - struct socket *subflow; char __kernel *optval; - pr_debug("msk=%p", msk); - if (msk->connection_list) { - subflow = msk->connection_list; - pr_debug("conn_list->subflow=%p", subflow_ctx(subflow->sk)); - } else { - subflow = msk->subflow; - pr_debug("subflow=%p", subflow_ctx(subflow->sk)); - } - /* will be treated as __user in tcp_setsockopt */ optval = (char __kernel __force *)uoptval; - return kernel_setsockopt(subflow, level, optname, optval, optlen); + pr_debug("msk=%p", msk); + if (msk->subflow) { + pr_debug("subflow=%p", msk->subflow->sk); + return kernel_setsockopt(msk->subflow, level, optname, optval, + optlen); + } + + /* @@ the meaning of setsockopt() when the socket is connected and + * there are multiple subflows is not defined. 
+ */ + return 0; } static int mptcp_getsockopt(struct sock *sk, int level, int optname, char __user *uoptval, int __user *uoption) { struct mptcp_sock *msk = mptcp_sk(sk); - struct socket *subflow; char __kernel *optval; int __kernel *option; - pr_debug("msk=%p", msk); - if (msk->connection_list) { - subflow = msk->connection_list; - pr_debug("conn_list->subflow=%p", subflow_ctx(subflow->sk)); - } else { - subflow = msk->subflow; - pr_debug("subflow=%p", subflow_ctx(subflow->sk)); - } - /* will be treated as __user in tcp_getsockopt */ optval = (char __kernel __force *)uoptval; option = (int __kernel __force *)uoption; - return kernel_getsockopt(subflow, level, optname, optval, option); + pr_debug("msk=%p", msk); + if (msk->subflow) { + pr_debug("subflow=%p", msk->subflow->sk); + return kernel_getsockopt(msk->subflow, level, optname, optval, + option); + } + + /* @@ the meaning of setsockopt() when the socket is connected and + * there are multiple subflows is not defined. + */ + return 0; } static int mptcp_get_port(struct sock *sk, unsigned short snum) @@ -613,8 +674,10 @@ void mptcp_finish_connect(struct sock *sk, int mp_capable) msk->local_key = subflow->local_key; msk->token = subflow->token; pr_debug("msk=%p, token=%u", msk, msk->token); - msk->connection_list = msk->subflow; + spin_lock_bh(&msk->conn_list_lock); + list_add_rcu(&subflow->node, &msk->conn_list); msk->subflow = NULL; + spin_unlock_bh(&msk->conn_list_lock); crypto_key_sha1(msk->remote_key, NULL, &ack_seq); msk->write_seq = subflow->idsn + 1; @@ -715,17 +778,32 @@ static int mptcp_getname(struct socket *sock, struct sockaddr *uaddr, int peer) { struct mptcp_sock *msk = mptcp_sk(sock->sk); - struct socket *subflow; - int err = -EPERM; + struct subflow_context *subflow; + struct hlist_node *node; + struct sock *ssk; + int ret; - if (msk->connection_list) - subflow = msk->connection_list; - else - subflow = msk->subflow; + pr_debug("msk=%p", msk); - err = inet_getname(subflow, uaddr, peer); + if 
(msk->subflow) { + pr_debug("subflow=%p", msk->subflow->sk); + return inet_getname(msk->subflow, uaddr, peer); + } - return err; + /* @@ the meaning of getname() for the remote peer when the socket + * is connected and there are multiple subflows is not defined. + * For now just use the first subflow on the list. + */ + rcu_read_lock(); + node = rcu_dereference(hlist_first_rcu(&msk->conn_list)); + subflow = hlist_entry(node, struct subflow_context, node); + ssk = mptcp_subflow_tcp_socket(subflow)->sk; + sock_hold(ssk); + rcu_read_unlock(); + + ret = inet_getname(mptcp_subflow_tcp_socket(subflow), uaddr, peer); + sock_put(ssk); + return ret; } static int mptcp_listen(struct socket *sock, int backlog) @@ -760,31 +838,59 @@ static __poll_t mptcp_poll(struct file *file, struct socket *sock, struct poll_table_struct *wait) { const struct mptcp_sock *msk; + struct subflow_context *subflow; struct sock *sk = sock->sk; + struct hlist_node *node; + struct sock *ssk; + __poll_t ret; msk = mptcp_sk(sk); if (msk->subflow) return tcp_poll(file, msk->subflow, wait); - return tcp_poll(file, msk->connection_list, wait); + rcu_read_lock(); + node = rcu_dereference(hlist_first_rcu(&msk->conn_list)); + subflow = hlist_entry(node, struct subflow_context, node); + ssk = mptcp_subflow_tcp_socket(subflow)->sk; + sock_hold(ssk); + rcu_read_unlock(); + + ret = tcp_poll(file, ssk->sk_socket, wait); + sock_put(ssk); + return ret; } static int mptcp_shutdown(struct socket *sock, int how) { struct mptcp_sock *msk = mptcp_sk(sock->sk); + struct subflow_context *subflow; int ret = 0; pr_debug("sk=%p, how=%d", msk, how); if (msk->subflow) { pr_debug("subflow=%p", msk->subflow->sk); - ret = kernel_sock_shutdown(msk->subflow, how); + return kernel_sock_shutdown(msk->subflow, how); } - if (msk->connection_list) { - pr_debug("conn_list->subflow=%p", msk->connection_list->sk); - ret = kernel_sock_shutdown(msk->connection_list, how); + /* protect against concurrent mptcp_close(), so that nobody can + * 
remove entries from the conn list and walking the list breaking + * the RCU critical section is still safe. We need to release the + * RCU lock to call the blocking kernel_sock_shutdown() primitive + * Note: we can't use MPTCP socket lock to protect conn_list changes, + * as we need to update it from the BH via the mptcp_finish_connect() + */ + lock_sock(sock->sk); + rcu_read_lock(); + list_for_each_entry_rcu(subflow, &msk->conn_list, node) { + pr_debug("conn_list->subflow=%p", subflow); + rcu_read_unlock(); + ret = kernel_sock_shutdown(mptcp_subflow_tcp_socket(subflow), + how); + rcu_read_lock(); } + rcu_read_unlock(); + release_sock(sock->sk); return ret; } diff --git a/net/mptcp/protocol.h b/net/mptcp/protocol.h index 5c840f76a9b9..a1bf093bb37e 100644 --- a/net/mptcp/protocol.h +++ b/net/mptcp/protocol.h @@ -7,6 +7,8 @@ #ifndef __MPTCP_PROTOCOL_H #define __MPTCP_PROTOCOL_H +#include <linux/spinlock.h> + /* MPTCP option subtypes */ #define MPTCPOPT_MP_CAPABLE 0 #define MPTCPOPT_MP_JOIN 1 @@ -52,10 +54,14 @@ struct mptcp_sock { u64 write_seq; u64 ack_seq; u32 token; - struct socket *connection_list; /* @@ needs to be a list */ + spinlock_t conn_list_lock; + struct list_head conn_list; struct socket *subflow; /* outgoing connect/listener/!mp_capable */ }; +#define mptcp_for_each_subflow(__msk, __subflow) \ + list_for_each_entry_rcu(__subflow, &((__msk)->conn_list), node) + static inline struct mptcp_sock *mptcp_sk(const struct sock *sk) { return (struct mptcp_sock *)sk; @@ -83,6 +89,7 @@ struct subflow_request_sock *subflow_rsk(const struct request_sock *rsk) /* MPTCP subflow context */ struct subflow_context { + struct list_head node;/* conn_list of subflows */ u64 local_key; u64 remote_key; u32 token; -- 2.22.0