Ensure that the tunnel's socket is always extant while the tunnel
object exists. Hold a ref on the socket until the tunnel is destroyed
and ensure that all tunnel destroy paths go through a common function
(l2tp_tunnel_delete).

Since the tunnel's socket is now guaranteed to exist if the tunnel
exists, we no longer need to use sockfd_lookup via l2tp_sock_to_tunnel
to derive the tunnel from the socket since this is always
sk_user_data.

The tunnel object gains a new closing flag which is protected by a
spinlock. The existing dead flag which is accessed using
test_and_set_bit APIs is no longer used so is removed.

Fixes: 80d84ef3ff1dd ("l2tp: prevent l2tp_tunnel_delete racing with userspace close")
---
 net/l2tp/l2tp_core.c | 128 ++++++++++++++++++---------------------------------
 net/l2tp/l2tp_core.h |  26 ++---------
 net/l2tp/l2tp_ip.c   |   5 +-
 net/l2tp/l2tp_ip6.c  |   3 +-
 4 files changed, 52 insertions(+), 110 deletions(-)

diff --git a/net/l2tp/l2tp_core.c b/net/l2tp/l2tp_core.c
index b68ae77e021e..49d6e06099ec 100644
--- a/net/l2tp/l2tp_core.c
+++ b/net/l2tp/l2tp_core.c
@@ -136,51 +136,6 @@ static inline struct l2tp_net *l2tp_pernet(const struct net *net)
 
 }
 
-/* Lookup the tunnel socket, possibly involving the fs code if the socket is
- * owned by userspace.  A struct sock returned from this function must be
- * released using l2tp_tunnel_sock_put once you're done with it.
- */
-static struct sock *l2tp_tunnel_sock_lookup(struct l2tp_tunnel *tunnel)
-{
-       int err = 0;
-       struct socket *sock = NULL;
-       struct sock *sk = NULL;
-
-       if (!tunnel)
-               goto out;
-
-       if (tunnel->fd >= 0) {
-               /* Socket is owned by userspace, who might be in the process
-                * of closing it.  Look the socket up using the fd to ensure
-                * consistency.
-                */
-               sock = sockfd_lookup(tunnel->fd, &err);
-               if (sock)
-                       sk = sock->sk;
-       } else {
-               /* Socket is owned by kernelspace */
-               sk = tunnel->sock;
-               sock_hold(sk);
-       }
-
-out:
-       return sk;
-}
-
-/* Drop a reference to a tunnel socket obtained via. l2tp_tunnel_sock_put */
-static void l2tp_tunnel_sock_put(struct sock *sk)
-{
-       struct l2tp_tunnel *tunnel = l2tp_sock_to_tunnel(sk);
-       if (tunnel) {
-               if (tunnel->fd >= 0) {
-                       /* Socket is owned by userspace */
-                       sockfd_put(sk->sk_socket);
-               }
-               sock_put(sk);
-       }
-       sock_put(sk);
-}
-
 /* Session hash list.
  * The session_id SHOULD be random according to RFC2661, but several
  * L2TP implementations (Cisco and Microsoft) use incrementing
@@ -193,6 +148,12 @@ static void l2tp_tunnel_sock_put(struct sock *sk)
        return &tunnel->session_hlist[hash_32(session_id, L2TP_HASH_BITS)];
 }
 
+void l2tp_tunnel_free(struct l2tp_tunnel *tunnel)
+{
+       sock_put(tunnel->sock);
+       /* the tunnel is freed in the socket destructor */
+}
+
 /* Lookup a tunnel. A new reference is held on the returned tunnel. */
 struct l2tp_tunnel *l2tp_tunnel_get(const struct net *net, u32 tunnel_id)
 {
@@ -969,7 +930,7 @@ int l2tp_udp_encap_recv(struct sock *sk, struct sk_buff *skb)
 {
        struct l2tp_tunnel *tunnel;
 
-       tunnel = l2tp_sock_to_tunnel(sk);
+       tunnel = l2tp_tunnel(sk);
        if (tunnel == NULL)
                goto pass_up;
 
@@ -977,13 +938,10 @@ int l2tp_udp_encap_recv(struct sock *sk, struct sk_buff *skb)
                 tunnel->name, skb->len);
 
        if (l2tp_udp_recv_core(tunnel, skb, tunnel->recv_payload_hook))
-               goto pass_up_put;
+               goto pass_up;
 
-       sock_put(sk);
        return 0;
 
-pass_up_put:
-       sock_put(sk);
 pass_up:
        return 1;
 }
@@ -1214,7 +1172,6 @@ static void l2tp_tunnel_destruct(struct sock *sk)
 
        l2tp_info(tunnel, L2TP_MSG_CONTROL, "%s: closing...\n", tunnel->name);
 
-
        /* Disable udp encapsulation */
        write_lock_bh(&sk->sk_callback_lock);
        switch (tunnel->encap) {
@@ -1239,12 +1196,11 @@ static void l2tp_tunnel_destruct(struct sock *sk)
        list_del_rcu(&tunnel->list);
        spin_unlock_bh(&pn->l2tp_tunnel_list_lock);
 
-       tunnel->sock = NULL;
-       l2tp_tunnel_dec_refcount(tunnel);
-
        /* Call the original destructor */
        if (sk->sk_destruct)
                (*sk->sk_destruct)(sk);
+
+       kfree_rcu(tunnel, rcu);
 end:
        return;
 }
@@ -1305,30 +1261,26 @@ void l2tp_tunnel_closeall(struct l2tp_tunnel *tunnel)
 /* Tunnel socket destroy hook for UDP encapsulation */
 static void l2tp_udp_encap_destroy(struct sock *sk)
 {
-       struct l2tp_tunnel *tunnel = l2tp_sock_to_tunnel(sk);
+       struct l2tp_tunnel *tunnel;
+
+       rcu_read_lock();
+       tunnel = rcu_dereference_sk_user_data(sk);
        if (tunnel) {
-               l2tp_tunnel_closeall(tunnel);
-               sock_put(sk);
+               l2tp_tunnel_delete(tunnel);
        }
+       rcu_read_unlock();
 }
 
 /* Workqueue tunnel deletion function */
 static void l2tp_tunnel_del_work(struct work_struct *work)
 {
-       struct l2tp_tunnel *tunnel = NULL;
-       struct socket *sock = NULL;
-       struct sock *sk = NULL;
-
-       tunnel = container_of(work, struct l2tp_tunnel, del_work);
+       struct l2tp_tunnel *tunnel = container_of(work, struct l2tp_tunnel,
+                                                 del_work);
+       struct sock *sk = tunnel->sock;
+       struct socket *sock = sk->sk_socket;
 
        l2tp_tunnel_closeall(tunnel);
 
-       sk = l2tp_tunnel_sock_lookup(tunnel);
-       if (!sk)
-               goto out;
-
-       sock = sk->sk_socket;
-
        /* If the tunnel socket was created within the kernel, use
         * the sk API to release it here.
         */
@@ -1339,8 +1291,10 @@ static void l2tp_tunnel_del_work(struct work_struct *work)
                }
        }
 
-       l2tp_tunnel_sock_put(sk);
-out:
+       /* drop initial ref */
+       l2tp_tunnel_dec_refcount(tunnel);
+
+       /* drop workqueue ref */
        l2tp_tunnel_dec_refcount(tunnel);
 }
 
@@ -1550,6 +1504,7 @@ int l2tp_tunnel_create(struct net *net, int fd, int version, u32 tunnel_id, u32
 
        tunnel->magic = L2TP_TUNNEL_MAGIC;
        sprintf(&tunnel->name[0], "tunl %u", tunnel_id);
+       spin_lock_init(&tunnel->lock);
        rwlock_init(&tunnel->hlist_lock);
        tunnel->acpt_newsess = true;
 
@@ -1605,14 +1560,23 @@ int l2tp_tunnel_create(struct net *net, int fd, int version, u32 tunnel_id, u32
                setup_udp_tunnel_sock(net, sock, &udp_cfg);
        }
 
+       /* Bump the reference count. The tunnel context is deleted
+        * only when this drops to zero. A reference is also held on
+        * the tunnel socket to ensure that it is not released while
+        * the tunnel is extant. Must be done before sk_destruct is
+        * set.
+        */
+       refcount_set(&tunnel->ref_count, 1);
+       sock_hold(sk);
+       tunnel->sock = sk;
+       tunnel->fd = fd;
+
        /* Hook on the tunnel socket destructor so that we can cleanup
         * if the tunnel socket goes away.
         */
        tunnel->old_sk_destruct = sk->sk_destruct;
        sk->sk_destruct = &l2tp_tunnel_destruct;
 
-       tunnel->sock = sk;
-       tunnel->fd = fd;
        lockdep_set_class_and_name(&sk->sk_lock.slock, &l2tp_socket_class, "l2tp_sock");
 
        sk->sk_allocation = GFP_ATOMIC;
@@ -1622,11 +1586,6 @@ int l2tp_tunnel_create(struct net *net, int fd, int version, u32 tunnel_id, u32
 
        /* Add tunnel to our list */
        INIT_LIST_HEAD(&tunnel->list);
-
-       /* Bump the reference count. The tunnel context is deleted
-        * only when this drops to zero. Must be done before list insertion
-        */
-       refcount_set(&tunnel->ref_count, 1);
        spin_lock_bh(&pn->l2tp_tunnel_list_lock);
        list_add_rcu(&tunnel->list, &pn->l2tp_tunnel_list);
        spin_unlock_bh(&pn->l2tp_tunnel_list_lock);
@@ -1650,10 +1609,17 @@ int l2tp_tunnel_create(struct net *net, int fd, int version, u32 tunnel_id, u32
  */
 void l2tp_tunnel_delete(struct l2tp_tunnel *tunnel)
 {
-       if (!test_and_set_bit(0, &tunnel->dead)) {
-               l2tp_tunnel_inc_refcount(tunnel);
-               queue_work(l2tp_wq, &tunnel->del_work);
+       spin_lock_bh(&tunnel->lock);
+       if (tunnel->closing) {
+               spin_unlock_bh(&tunnel->lock);
+               return;
        }
+       tunnel->closing = true;
+       spin_unlock_bh(&tunnel->lock);
+
+       /* Hold tunnel ref while queued work item is pending */
+       l2tp_tunnel_inc_refcount(tunnel);
+       queue_work(l2tp_wq, &tunnel->del_work);
 }
 EXPORT_SYMBOL_GPL(l2tp_tunnel_delete);
 
@@ -1667,8 +1633,6 @@ void l2tp_session_free(struct l2tp_session *session)
 
        if (tunnel) {
                BUG_ON(tunnel->magic != L2TP_TUNNEL_MAGIC);
-               sock_put(tunnel->sock);
-               session->tunnel = NULL;
                l2tp_tunnel_dec_refcount(tunnel);
        }
 
diff --git a/net/l2tp/l2tp_core.h b/net/l2tp/l2tp_core.h
index 9bbee90e9963..e88ff7895ccb 100644
--- a/net/l2tp/l2tp_core.h
+++ b/net/l2tp/l2tp_core.h
@@ -155,7 +155,8 @@ struct l2tp_tunnel_cfg {
 struct l2tp_tunnel {
        int                     magic;          /* Should be L2TP_TUNNEL_MAGIC */
 
-       unsigned long           dead;
+       bool                    closing;
+       spinlock_t              lock;           /* protect closing */
 
        struct rcu_head rcu;
        rwlock_t                hlist_lock;     /* protect session_hlist */
@@ -214,27 +215,8 @@ static inline void *l2tp_session_priv(struct l2tp_session *session)
        return &session->priv[0];
 }
 
-static inline struct l2tp_tunnel *l2tp_sock_to_tunnel(struct sock *sk)
-{
-       struct l2tp_tunnel *tunnel;
-
-       if (sk == NULL)
-               return NULL;
-
-       sock_hold(sk);
-       tunnel = (struct l2tp_tunnel *)(sk->sk_user_data);
-       if (tunnel == NULL) {
-               sock_put(sk);
-               goto out;
-       }
-
-       BUG_ON(tunnel->magic != L2TP_TUNNEL_MAGIC);
-
-out:
-       return tunnel;
-}
-
 struct l2tp_tunnel *l2tp_tunnel_get(const struct net *net, u32 tunnel_id);
+void l2tp_tunnel_free(struct l2tp_tunnel *tunnel);
 
 struct l2tp_session *l2tp_session_get(const struct net *net,
                                      struct l2tp_tunnel *tunnel,
@@ -283,7 +265,7 @@ static inline void l2tp_tunnel_inc_refcount(struct l2tp_tunnel *tunnel)
 static inline void l2tp_tunnel_dec_refcount(struct l2tp_tunnel *tunnel)
 {
        if (refcount_dec_and_test(&tunnel->ref_count))
-               kfree_rcu(tunnel, rcu);
+               l2tp_tunnel_free(tunnel);
 }
 
 /* Session reference counts. Incremented when code obtains a reference
diff --git a/net/l2tp/l2tp_ip.c b/net/l2tp/l2tp_ip.c
index 42f3c2f72bf4..a5591bd2fa24 100644
--- a/net/l2tp/l2tp_ip.c
+++ b/net/l2tp/l2tp_ip.c
@@ -242,12 +242,9 @@ static void l2tp_ip_destroy_sock(struct sock *sk)
        rcu_read_lock();
        tunnel = rcu_dereference_sk_user_data(sk);
        if (tunnel) {
-               l2tp_tunnel_closeall(tunnel);
-               sock_put(sk);
+               l2tp_tunnel_delete(tunnel);
        }
        rcu_read_unlock();
-
-       sk_refcnt_debug_dec(sk);
 }
 
 static int l2tp_ip_bind(struct sock *sk, struct sockaddr *uaddr, int addr_len)
diff --git a/net/l2tp/l2tp_ip6.c b/net/l2tp/l2tp_ip6.c
index be4a3eee85a9..de8e7eb7a638 100644
--- a/net/l2tp/l2tp_ip6.c
+++ b/net/l2tp/l2tp_ip6.c
@@ -257,8 +257,7 @@ static void l2tp_ip6_destroy_sock(struct sock *sk)
        rcu_read_lock();
        tunnel = rcu_dereference_sk_user_data(sk);
        if (tunnel) {
-               l2tp_tunnel_closeall(tunnel);
-               sock_put(sk);
+               l2tp_tunnel_delete(tunnel);
        }
        rcu_read_unlock();
 
-- 
1.9.1

Reply via email to