vsock_poll() reads vsk->peer_shutdown before taking the socket lock
to set EPOLLHUP and EPOLLRDHUP, then reads it again after taking
the lock to report EOF readability. A shutdown packet can update
peer_shutdown while poll is waiting for the lock, so one poll invocation
can report EOF readability without the corresponding HUP/RDHUP bits.

For connectible sockets, take one peer_shutdown snapshot after
lock_sock() and use it for all peer-shutdown-derived poll bits. For
datagram sockets, which do not take lock_sock() in poll(), take one
lockless READ_ONCE() snapshot and pair it with WRITE_ONCE() on the
writer side.

This keeps the peer-shutdown-derived bits internally consistent for each
poll pass.

Fixes: d021c344051a ("VSOCK: Introduce VM Sockets")
Signed-off-by: Ziyu Zhang <[email protected]>
---
Link: https://lore.kernel.org/netdev/[email protected]/

v2:
- Pair lockless READ_ONCE() users with WRITE_ONCE() on peer_shutdown writers.
- Move datagram shutdown handling into the SOCK_DGRAM branch and add a comment.
- Keep one connectible peer_shutdown snapshot after lock_sock() and
  restore the previous shutdown-derived mask ordering.

 net/vmw_vsock/af_vsock.c                | 49 ++++++++++++++++---------
 net/vmw_vsock/hyperv_transport.c        |  9 +++--
 net/vmw_vsock/virtio_transport_common.c | 14 ++++---
 net/vmw_vsock/vmci_transport.c          |  8 ++--
 4 files changed, 52 insertions(+), 28 deletions(-)

diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c
index adcba1b7b..789b00f6e 100644
--- a/net/vmw_vsock/af_vsock.c
+++ b/net/vmw_vsock/af_vsock.c
@@ -523,7 +523,7 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
                 */
                sock_reset_flag(sk, SOCK_DONE);
                sk->sk_state = TCP_CLOSE;
-               vsk->peer_shutdown = 0;
+               WRITE_ONCE(vsk->peer_shutdown, 0);
        }
 
        if (sk->sk_type == SOCK_SEQPACKET) {
@@ -814,7 +814,7 @@ static struct sock *__vsock_create(struct net *net,
        vsk->rejected = false;
        vsk->sent_request = false;
        vsk->ignore_connecting_rst = false;
-       vsk->peer_shutdown = 0;
+       WRITE_ONCE(vsk->peer_shutdown, 0);
        INIT_DELAYED_WORK(&vsk->connect_work, vsock_connect_timeout);
        INIT_DELAYED_WORK(&vsk->pending_work, vsock_pending_work);
 
@@ -1122,6 +1122,25 @@ static int vsock_shutdown(struct socket *sock, int mode)
        return err;
 }
 
+static __poll_t vsock_poll_shutdown(struct sock *sk, u32 peer_shutdown)
+{
+       __poll_t mask = 0;
+
+       /* Like INET sockets, report EPOLLHUP once sk_shutdown is SHUTDOWN_MASK
+        * or once both the local and the peer write sides are shut down.
+        */
+       if (sk->sk_shutdown == SHUTDOWN_MASK ||
+           ((sk->sk_shutdown & SEND_SHUTDOWN) &&
+            (peer_shutdown & SEND_SHUTDOWN)))
+               mask |= EPOLLHUP;
+
+       if (sk->sk_shutdown & RCV_SHUTDOWN ||
+           peer_shutdown & SEND_SHUTDOWN)
+               mask |= EPOLLRDHUP;
+
+       return mask;
+}
+
 static __poll_t vsock_poll(struct file *file, struct socket *sock,
                               poll_table *wait)
 {
@@ -1139,24 +1158,17 @@ static __poll_t vsock_poll(struct file *file, struct socket *sock,
                /* Signify that there has been an error on this socket. */
                mask |= EPOLLERR;
 
-       /* INET sockets treat local write shutdown and peer write shutdown as a
-        * case of EPOLLHUP set.
-        */
-       if ((sk->sk_shutdown == SHUTDOWN_MASK) ||
-           ((sk->sk_shutdown & SEND_SHUTDOWN) &&
-            (vsk->peer_shutdown & SEND_SHUTDOWN))) {
-               mask |= EPOLLHUP;
-       }
-
-       if (sk->sk_shutdown & RCV_SHUTDOWN ||
-           vsk->peer_shutdown & SEND_SHUTDOWN) {
-               mask |= EPOLLRDHUP;
-       }
-
        if (sk_is_readable(sk))
                mask |= EPOLLIN | EPOLLRDNORM;
 
        if (sock->type == SOCK_DGRAM) {
+               u32 peer_shutdown = READ_ONCE(vsk->peer_shutdown);
+
+               /* DGRAM sockets do not take lock_sock() in poll(), so use one
+                * lockless snapshot for all shutdown-derived mask bits.
+                */
+               mask |= vsock_poll_shutdown(sk, peer_shutdown);
+
                /* For datagram sockets we can read if there is something in
                 * the queue and write as long as the socket isn't shutdown for
                 * sending.
@@ -1171,6 +1183,7 @@ static __poll_t vsock_poll(struct file *file, struct socket *sock,
 
        } else if (sock_type_connectible(sk->sk_type)) {
                const struct vsock_transport *transport;
+               u32 peer_shutdown;
 
                lock_sock(sk);
 
@@ -1203,8 +1216,10 @@ static __poll_t vsock_poll(struct file *file, struct socket *sock,
                 * terminated should also be considered read, and we check the
                 * shutdown flag for that.
                 */
+               peer_shutdown = READ_ONCE(vsk->peer_shutdown);
+               mask |= vsock_poll_shutdown(sk, peer_shutdown);
                if (sk->sk_shutdown & RCV_SHUTDOWN ||
-                   vsk->peer_shutdown & SEND_SHUTDOWN) {
+                   peer_shutdown & SEND_SHUTDOWN) {
                        mask |= EPOLLIN | EPOLLRDNORM;
                }
 
diff --git a/net/vmw_vsock/hyperv_transport.c b/net/vmw_vsock/hyperv_transport.c
index 432fcbbd1..16b981566 100644
--- a/net/vmw_vsock/hyperv_transport.c
+++ b/net/vmw_vsock/hyperv_transport.c
@@ -264,7 +264,7 @@ static void hvs_do_close_lock_held(struct vsock_sock *vsk,
        struct sock *sk = sk_vsock(vsk);
 
        sock_set_flag(sk, SOCK_DONE);
-       vsk->peer_shutdown = SHUTDOWN_MASK;
+       WRITE_ONCE(vsk->peer_shutdown, SHUTDOWN_MASK);
        if (vsock_stream_has_data(vsk) <= 0)
                sk->sk_state = TCP_CLOSING;
        sk->sk_state_change(sk);
@@ -593,7 +593,9 @@ static int hvs_update_recv_data(struct hvsock *hvs)
                return -EIO;
 
        if (payload_len == 0)
-               hvs->vsk->peer_shutdown |= SEND_SHUTDOWN;
+               WRITE_ONCE(hvs->vsk->peer_shutdown,
+                          READ_ONCE(hvs->vsk->peer_shutdown) |
+                          SEND_SHUTDOWN);
 
        hvs->recv_data_len = payload_len;
        hvs->recv_data_off = 0;
@@ -715,7 +717,8 @@ static s64 hvs_stream_has_data(struct vsock_sock *vsk)
                        return ret;
                return hvs->recv_data_len;
        case 0:
-               vsk->peer_shutdown |= SEND_SHUTDOWN;
+               WRITE_ONCE(vsk->peer_shutdown,
+                          READ_ONCE(vsk->peer_shutdown) | SEND_SHUTDOWN);
                ret = 0;
                break;
        default: /* -1 */
diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c
index dcc8a1d58..71d8eac82 100644
--- a/net/vmw_vsock/virtio_transport_common.c
+++ b/net/vmw_vsock/virtio_transport_common.c
@@ -1220,7 +1220,7 @@ static void virtio_transport_do_close(struct vsock_sock *vsk,
        struct sock *sk = sk_vsock(vsk);
 
        sock_set_flag(sk, SOCK_DONE);
-       vsk->peer_shutdown = SHUTDOWN_MASK;
+       WRITE_ONCE(vsk->peer_shutdown, SHUTDOWN_MASK);
        if (vsock_stream_has_data(vsk) <= 0)
                sk->sk_state = TCP_CLOSING;
        sk->sk_state_change(sk);
@@ -1411,12 +1411,15 @@ virtio_transport_recv_connected(struct sock *sk,
        case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
                sk->sk_write_space(sk);
                break;
-       case VIRTIO_VSOCK_OP_SHUTDOWN:
+       case VIRTIO_VSOCK_OP_SHUTDOWN: {
+               u32 peer_shutdown = READ_ONCE(vsk->peer_shutdown);
+
                if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SHUTDOWN_RCV)
-                       vsk->peer_shutdown |= RCV_SHUTDOWN;
+                       peer_shutdown |= RCV_SHUTDOWN;
                if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SHUTDOWN_SEND)
-                       vsk->peer_shutdown |= SEND_SHUTDOWN;
-               if (vsk->peer_shutdown == SHUTDOWN_MASK) {
+                       peer_shutdown |= SEND_SHUTDOWN;
+               WRITE_ONCE(vsk->peer_shutdown, peer_shutdown);
+               if (peer_shutdown == SHUTDOWN_MASK) {
                        if (vsock_stream_has_data(vsk) <= 0 && !sock_flag(sk, SOCK_DONE)) {
                                (void)virtio_transport_reset(vsk, NULL);
                                virtio_transport_do_close(vsk, true);
@@ -1431,6 +1434,7 @@ virtio_transport_recv_connected(struct sock *sk,
                if (le32_to_cpu(virtio_vsock_hdr(skb)->flags))
                        sk->sk_state_change(sk);
                break;
+       }
        case VIRTIO_VSOCK_OP_RST:
                virtio_transport_do_close(vsk, true);
                break;
diff --git a/net/vmw_vsock/vmci_transport.c b/net/vmw_vsock/vmci_transport.c
index 7eccd6708..c2231c402 100644
--- a/net/vmw_vsock/vmci_transport.c
+++ b/net/vmw_vsock/vmci_transport.c
@@ -811,7 +811,7 @@ static void vmci_transport_handle_detach(struct sock *sk)
                /* On a detach the peer will not be sending or receiving
                 * anymore.
                 */
-               vsk->peer_shutdown = SHUTDOWN_MASK;
+               WRITE_ONCE(vsk->peer_shutdown, SHUTDOWN_MASK);
 
                /* We should not be sending anymore since the peer won't be
                 * there to receive, but we can still receive if there is data
@@ -1534,7 +1534,9 @@ static int vmci_transport_recv_connected(struct sock *sk,
                if (pkt->u.mode) {
                        vsk = vsock_sk(sk);
 
-                       vsk->peer_shutdown |= pkt->u.mode;
+                       WRITE_ONCE(vsk->peer_shutdown,
+                                  READ_ONCE(vsk->peer_shutdown) |
+                                  pkt->u.mode);
                        sk->sk_state_change(sk);
                }
                break;
@@ -1551,7 +1553,7 @@ static int vmci_transport_recv_connected(struct sock *sk,
                 * a clean shutdown.
                 */
                sock_set_flag(sk, SOCK_DONE);
-               vsk->peer_shutdown = SHUTDOWN_MASK;
+               WRITE_ONCE(vsk->peer_shutdown, SHUTDOWN_MASK);
                if (vsock_stream_has_data(vsk) <= 0)
                        sk->sk_state = TCP_CLOSING;
 
-- 
2.43.0


Reply via email to