On Sat, Jun 10, 2023 at 12:58:30AM +0000, Bobby Eshleman wrote:
This patch adds support for multi-transport datagrams.

This includes:
- Per-packet lookup of transports when using sendto(sockaddr_vm)
- Selecting H2G or G2H transport using VMADDR_FLAG_TO_HOST and CID in
 sockaddr_vm

To preserve backwards compatibility with VMCI, some important changes
were made. The "transport_dgram" / VSOCK_TRANSPORT_F_DGRAM is changed to
be used for dgrams iff there is not yet a g2h or h2g transport that has

s/iff/if

been registered that can transmit the packet. If there is a g2h/h2g
transport for that remote address, then that transport will be used and
not "transport_dgram". This essentially makes "transport_dgram" a
fallback transport for when h2g/g2h has not yet gone online, which
appears to be the exact use case for VMCI.

This design makes sense, because there is no reason that the
transport_{g2h,h2g} cannot also service datagrams, which makes the role
of transport_dgram difficult to understand outside of the VMCI context.

The logic around "transport_dgram" had to be retained to prevent
breaking VMCI:

1) VMCI datagrams appear to function outside of the h2g/g2h
  paradigm. When the vmci transport becomes online, it registers itself
  with the DGRAM feature, but not H2G/G2H. Only later when the
  transport has more information about its environment does it register
  H2G or G2H. In the case that a datagram socket becomes active
  after DGRAM registration but before G2H/H2G registration, the
  "transport_dgram" transport needs to be used.

IIRC we did this because at that time only VMCI supported DGRAM. Now that there are more transports, maybe DGRAM can follow the h2g/g2h paradigm.


2) VMCI seems to require a special message to be sent by the transport when a
  datagram socket calls bind(). Under the h2g/g2h model, the transport
  is selected using the remote_addr which is set by connect(). At
  bind time there is no remote_addr because often no connect() has been
  called yet: the transport is null. Therefore, with a null transport
  there doesn't seem to be any good way for a datagram socket to tell the
  VMCI transport that it has just had bind() called upon it.

@Vishnu, @Bryan do you think we can avoid this in some way?


Only transports with a special datagram fallback use-case such as VMCI
need to register VSOCK_TRANSPORT_F_DGRAM.

Maybe we should rename it to VSOCK_TRANSPORT_F_DGRAM_FALLBACK or
something like that.

In any case, we definitely need to update the comment above
VSOCK_TRANSPORT_F_DGRAM in include/net/af_vsock.h to mention this.


Signed-off-by: Bobby Eshleman <bobby.eshle...@bytedance.com>
---
drivers/vhost/vsock.c                   |  1 -
include/linux/virtio_vsock.h            |  2 -
net/vmw_vsock/af_vsock.c                | 78 +++++++++++++++++++++++++--------
net/vmw_vsock/hyperv_transport.c        |  6 ---
net/vmw_vsock/virtio_transport.c        |  1 -
net/vmw_vsock/virtio_transport_common.c |  7 ---
net/vmw_vsock/vsock_loopback.c          |  1 -
7 files changed, 60 insertions(+), 36 deletions(-)

diff --git a/drivers/vhost/vsock.c b/drivers/vhost/vsock.c
index c8201c070b4b..8f0082da5e70 100644
--- a/drivers/vhost/vsock.c
+++ b/drivers/vhost/vsock.c
@@ -410,7 +410,6 @@ static struct virtio_transport vhost_transport = {
                .cancel_pkt               = vhost_transport_cancel_pkt,

                .dgram_enqueue            = virtio_transport_dgram_enqueue,
-               .dgram_bind               = virtio_transport_dgram_bind,
                .dgram_allow              = virtio_transport_dgram_allow,
                .dgram_get_cid            = virtio_transport_dgram_get_cid,
                .dgram_get_port           = virtio_transport_dgram_get_port,
diff --git a/include/linux/virtio_vsock.h b/include/linux/virtio_vsock.h
index 23521a318cf0..73afa09f4585 100644
--- a/include/linux/virtio_vsock.h
+++ b/include/linux/virtio_vsock.h
@@ -216,8 +216,6 @@ void virtio_transport_notify_buffer_size(struct vsock_sock *vsk, u64 *val);
u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk);
bool virtio_transport_stream_is_active(struct vsock_sock *vsk);
bool virtio_transport_stream_allow(u32 cid, u32 port);
-int virtio_transport_dgram_bind(struct vsock_sock *vsk,
-                               struct sockaddr_vm *addr);
bool virtio_transport_dgram_allow(u32 cid, u32 port);
int virtio_transport_dgram_get_cid(struct sk_buff *skb, unsigned int *cid);
int virtio_transport_dgram_get_port(struct sk_buff *skb, unsigned int *port);
diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c
index 74358f0b47fa..ef86765f3765 100644
--- a/net/vmw_vsock/af_vsock.c
+++ b/net/vmw_vsock/af_vsock.c
@@ -438,6 +438,18 @@ vsock_connectible_lookup_transport(unsigned int cid, __u8 flags)
        return transport;
}

+static const struct vsock_transport *
+vsock_dgram_lookup_transport(unsigned int cid, __u8 flags)
+{
+       const struct vsock_transport *transport;
+
+       transport = vsock_connectible_lookup_transport(cid, flags);
+       if (transport)
+               return transport;
+
+       return transport_dgram;
+}
+
/* Assign a transport to a socket and call the .init transport callback.
 *
 * Note: for connection oriented socket this must be called when vsk->remote_addr
@@ -474,7 +486,8 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)

        switch (sk->sk_type) {
        case SOCK_DGRAM:
-               new_transport = transport_dgram;
+               new_transport = vsock_dgram_lookup_transport(remote_cid,
+                                                            remote_flags);
                break;
        case SOCK_STREAM:
        case SOCK_SEQPACKET:
@@ -691,6 +704,9 @@ static int __vsock_bind_connectible(struct vsock_sock *vsk,
static int __vsock_bind_dgram(struct vsock_sock *vsk,
                              struct sockaddr_vm *addr)
{
+       if (!vsk->transport || !vsk->transport->dgram_bind)
+               return -EINVAL;
+
        return vsk->transport->dgram_bind(vsk, addr);
}

@@ -1172,19 +1188,24 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,

        lock_sock(sk);

-       transport = vsk->transport;
-
-       err = vsock_auto_bind(vsk);
-       if (err)
-               goto out;
-
-
        /* If the provided message contains an address, use that.  Otherwise
         * fall back on the socket's remote handle (if it has been connected).
         */
        if (msg->msg_name &&
            vsock_addr_cast(msg->msg_name, msg->msg_namelen,
                            &remote_addr) == 0) {
+               transport = vsock_dgram_lookup_transport(remote_addr->svm_cid,
+                                                        remote_addr->svm_flags);
+               if (!transport) {
+                       err = -EINVAL;
+                       goto out;
+               }
+
+               if (!try_module_get(transport->module)) {
+                       err = -ENODEV;
+                       goto out;
+               }
+
                /* Ensure this address is of the right type and is a valid
                 * destination.
                 */
@@ -1193,11 +1214,27 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
                        remote_addr->svm_cid = transport->get_local_cid();


From here ...

                if (!vsock_addr_bound(remote_addr)) {
+                       module_put(transport->module);
+                       err = -EINVAL;
+                       goto out;
+               }
+
+               if (!transport->dgram_allow(remote_addr->svm_cid,
+                                           remote_addr->svm_port)) {
+                       module_put(transport->module);
                        err = -EINVAL;
                        goto out;
                }
+
+               err = transport->dgram_enqueue(vsk, remote_addr, msg, len);

... to here, this looks like duplicated code; can we move it out of the if block?

+               module_put(transport->module);
        } else if (sock->state == SS_CONNECTED) {
                remote_addr = &vsk->remote_addr;
+               transport = vsk->transport;
+
+               err = vsock_auto_bind(vsk);
+               if (err)
+                       goto out;

                if (remote_addr->svm_cid == VMADDR_CID_ANY)
                        remote_addr->svm_cid = transport->get_local_cid();
@@ -1205,23 +1242,23 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
                /* XXX Should connect() or this function ensure remote_addr is
                 * bound?
                 */
-               if (!vsock_addr_bound(&vsk->remote_addr)) {
+               if (!vsock_addr_bound(remote_addr)) {
                        err = -EINVAL;
                        goto out;
                }
-       } else {
-               err = -EINVAL;
-               goto out;
-       }

-       if (!transport->dgram_allow(remote_addr->svm_cid,
-                                   remote_addr->svm_port)) {
+               if (!transport->dgram_allow(remote_addr->svm_cid,
+                                           remote_addr->svm_port)) {
+                       err = -EINVAL;
+                       goto out;
+               }
+
+               err = transport->dgram_enqueue(vsk, remote_addr, msg, len);
+       } else {
                err = -EINVAL;
                goto out;
        }

-       err = transport->dgram_enqueue(vsk, remote_addr, msg, len);
-
out:
        release_sock(sk);
        return err;
@@ -1255,13 +1292,18 @@ static int vsock_dgram_connect(struct socket *sock,
        if (err)
                goto out;

+       memcpy(&vsk->remote_addr, remote_addr, sizeof(vsk->remote_addr));
+
+       err = vsock_assign_transport(vsk, NULL);
+       if (err)
+               goto out;
+
        if (!vsk->transport->dgram_allow(remote_addr->svm_cid,
                                         remote_addr->svm_port)) {
                err = -EINVAL;
                goto out;
        }

-       memcpy(&vsk->remote_addr, remote_addr, sizeof(vsk->remote_addr));
        sock->state = SS_CONNECTED;

        /* sock map disallows redirection of non-TCP sockets with sk_state !=
diff --git a/net/vmw_vsock/hyperv_transport.c b/net/vmw_vsock/hyperv_transport.c
index ff6e87e25fa0..c00bc5da769a 100644
--- a/net/vmw_vsock/hyperv_transport.c
+++ b/net/vmw_vsock/hyperv_transport.c
@@ -551,11 +551,6 @@ static void hvs_destruct(struct vsock_sock *vsk)
        kfree(hvs);
}

-static int hvs_dgram_bind(struct vsock_sock *vsk, struct sockaddr_vm *addr)
-{
-       return -EOPNOTSUPP;
-}
-
static int hvs_dgram_get_cid(struct sk_buff *skb, unsigned int *cid)
{
        return -EOPNOTSUPP;
@@ -841,7 +836,6 @@ static struct vsock_transport hvs_transport = {
        .connect                  = hvs_connect,
        .shutdown                 = hvs_shutdown,

-       .dgram_bind               = hvs_dgram_bind,
        .dgram_get_cid            = hvs_dgram_get_cid,
        .dgram_get_port           = hvs_dgram_get_port,
        .dgram_get_length         = hvs_dgram_get_length,
diff --git a/net/vmw_vsock/virtio_transport.c b/net/vmw_vsock/virtio_transport.c
index 5763cdf13804..1b7843a7779a 100644
--- a/net/vmw_vsock/virtio_transport.c
+++ b/net/vmw_vsock/virtio_transport.c
@@ -428,7 +428,6 @@ static struct virtio_transport virtio_transport = {
                .shutdown                 = virtio_transport_shutdown,
                .cancel_pkt               = virtio_transport_cancel_pkt,

-               .dgram_bind               = virtio_transport_dgram_bind,
                .dgram_enqueue            = virtio_transport_dgram_enqueue,
                .dgram_allow              = virtio_transport_dgram_allow,
                .dgram_get_cid            = virtio_transport_dgram_get_cid,
diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c
index e6903c719964..d5a3c8efe84b 100644
--- a/net/vmw_vsock/virtio_transport_common.c
+++ b/net/vmw_vsock/virtio_transport_common.c
@@ -790,13 +790,6 @@ bool virtio_transport_stream_allow(u32 cid, u32 port)
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_allow);

-int virtio_transport_dgram_bind(struct vsock_sock *vsk,
-                               struct sockaddr_vm *addr)
-{
-       return -EOPNOTSUPP;
-}
-EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);
-
int virtio_transport_dgram_get_cid(struct sk_buff *skb, unsigned int *cid)
{
        return -EOPNOTSUPP;
diff --git a/net/vmw_vsock/vsock_loopback.c b/net/vmw_vsock/vsock_loopback.c
index 2f3cabc79ee5..e9de45a26fbd 100644
--- a/net/vmw_vsock/vsock_loopback.c
+++ b/net/vmw_vsock/vsock_loopback.c
@@ -61,7 +61,6 @@ static struct virtio_transport loopback_transport = {
                .shutdown                 = virtio_transport_shutdown,
                .cancel_pkt               = vsock_loopback_cancel_pkt,

-               .dgram_bind               = virtio_transport_dgram_bind,
                .dgram_enqueue            = virtio_transport_dgram_enqueue,
                .dgram_allow              = virtio_transport_dgram_allow,
                .dgram_get_cid            = virtio_transport_dgram_get_cid,

--
2.30.2


The rest LGTM!

Stefano

_______________________________________________
Virtualization mailing list
Virtualization@lists.linux-foundation.org
https://lists.linuxfoundation.org/mailman/listinfo/virtualization

Reply via email to