On Wed, Mar 19, 2025 at 12:00:38PM -0700, Bobby Eshleman wrote:
On Wed, Mar 19, 2025 at 02:02:32PM +0100, Stefano Garzarella wrote:
On Wed, Mar 12, 2025 at 01:59:35PM -0700, Bobby Eshleman wrote:
> From: Stefano Garzarella <sgarz...@redhat.com>
>
> This patch adds a check of the "net" assigned to a socket during
> the vsock_find_bound_socket() and vsock_find_connected_socket()
> to support network namespace, allowing to share the same address
> (cid, port) across different network namespaces.
>
> This patch preserves old behavior, and does not yet bring up namespace
> support fully.
>
> Signed-off-by: Stefano Garzarella <sgarz...@redhat.com>

I'd add a short description here of the new `fallback` behaviour that you
developed.

Or we could split this patch in two: one with my changes without the
fallback, and another adding the fallback with you as the author.

WDYT?


I like the idea of splitting it; that way, any unforeseen issues in the
new logic are isolated to a single patch.


> Signed-off-by: Bobby Eshleman <bobbyeshle...@gmail.com>
> ---
> v1 -> v2:
> * remove 'netns' module param
> * remove vsock_net_eq()
> * use vsock_global_net() for "global" namespace
> * use fallback logic in socket lookup functions, giving precedence to
>  non-global vsock namespaces
>
> RFC -> v1
> * added 'netns' module param
> * added 'vsock_net_eq()' to check the "net" assigned to a socket
>  only when 'netns' support is enabled
> ---
> include/net/af_vsock.h                  |  7 +++--
> net/vmw_vsock/af_vsock.c                | 55 ++++++++++++++++++++++++---------
> net/vmw_vsock/hyperv_transport.c        |  2 +-
> net/vmw_vsock/virtio_transport_common.c |  5 +--
> net/vmw_vsock/vmci_transport.c          |  4 +--
> 5 files changed, 51 insertions(+), 22 deletions(-)
>
> diff --git a/include/net/af_vsock.h b/include/net/af_vsock.h
> index 
9e85424c834353d016a527070dd62e15ff3bfce1..41afbc18648c953da27a93571d408de968aa7668 
100644
> --- a/include/net/af_vsock.h
> +++ b/include/net/af_vsock.h
> @@ -213,9 +213,10 @@ void vsock_enqueue_accept(struct sock *listener, struct 
sock *connected);
> void vsock_insert_connected(struct vsock_sock *vsk);
> void vsock_remove_bound(struct vsock_sock *vsk);
> void vsock_remove_connected(struct vsock_sock *vsk);
> -struct sock *vsock_find_bound_socket(struct sockaddr_vm *addr);
> +struct sock *vsock_find_bound_socket(struct sockaddr_vm *addr, struct net 
*net);
> struct sock *vsock_find_connected_socket(struct sockaddr_vm *src,
> -                                   struct sockaddr_vm *dst);
> +                                   struct sockaddr_vm *dst,
> +                                   struct net *net);
> void vsock_remove_sock(struct vsock_sock *vsk);
> void vsock_for_each_connected_socket(struct vsock_transport *transport,
>                                 void (*fn)(struct sock *sk));
> @@ -255,4 +256,6 @@ static inline bool vsock_msgzerocopy_allow(const struct 
vsock_transport *t)
> {
>    return t->msgzerocopy_allow && t->msgzerocopy_allow();
> }
> +
> +struct net *vsock_global_net(void);

If it just returns NULL, maybe we can make it a static inline function here.


Roger that.

> #endif /* __AF_VSOCK_H__ */
> diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c
> index 
7e3db87ae4333cf63327ec105ca99253569bb9fe..d206489bf0a81cf989387c7c8063be91a7c21a7d 
100644
> --- a/net/vmw_vsock/af_vsock.c
> +++ b/net/vmw_vsock/af_vsock.c
> @@ -235,37 +235,60 @@ static void __vsock_remove_connected(struct vsock_sock 
*vsk)
>    sock_put(&vsk->sk);
> }
>
> -static struct sock *__vsock_find_bound_socket(struct sockaddr_vm *addr)
> +struct net *vsock_global_net(void)
> {
> +  return NULL;
> +}
> +EXPORT_SYMBOL_GPL(vsock_global_net);
> +
> +static struct sock *__vsock_find_bound_socket(struct sockaddr_vm *addr,
> +                                        struct net *net)
> +{

Please add a comment here describing what the fallback is used for.
I would also suggest adding something at the top of this file that briefly
explains how netns are handled in AF_VSOCK.


sgtm!

> +  struct sock *fallback = NULL;
>    struct vsock_sock *vsk;
>
>    list_for_each_entry(vsk, vsock_bound_sockets(addr), bound_table) {
> -          if (vsock_addr_equals_addr(addr, &vsk->local_addr))
> -                  return sk_vsock(vsk);
> +          if (vsock_addr_equals_addr(addr, &vsk->local_addr)) {
> +                  if (net_eq(net, sock_net(sk_vsock(vsk))))
> +                          return sk_vsock(vsk);
>
> +                  if (net_eq(net, vsock_global_net()))
> +                          fallback = sk_vsock(vsk);
> +          }
>            if (addr->svm_port == vsk->local_addr.svm_port &&
>                (vsk->local_addr.svm_cid == VMADDR_CID_ANY ||
> -               addr->svm_cid == VMADDR_CID_ANY))
> -                  return sk_vsock(vsk);
> +               addr->svm_cid == VMADDR_CID_ANY)) {
> +                  if (net_eq(net, sock_net(sk_vsock(vsk))))
> +                          return sk_vsock(vsk);
> +
> +                  if (net_eq(net, vsock_global_net()))
> +                          fallback = sk_vsock(vsk);
> +          }
>    }
>
> -  return NULL;
> +  return fallback;
> }
>
> static struct sock *__vsock_find_connected_socket(struct sockaddr_vm *src,
> -                                            struct sockaddr_vm *dst)
> +                                            struct sockaddr_vm *dst,
> +                                            struct net *net)
> {
> +  struct sock *fallback = NULL;
>    struct vsock_sock *vsk;
>
>    list_for_each_entry(vsk, vsock_connected_sockets(src, dst),
>                        connected_table) {
>            if (vsock_addr_equals_addr(src, &vsk->remote_addr) &&
>                dst->svm_port == vsk->local_addr.svm_port) {
> -                  return sk_vsock(vsk);
> +                  if (net_eq(net, sock_net(sk_vsock(vsk))))
> +                          return sk_vsock(vsk);
> +
> +                  if (net_eq(net, vsock_global_net()))
> +                          fallback = sk_vsock(vsk);

This pattern is repeated 3 times; can we factor it into a function/macro?


yep, no problem!

>            }
>    }
>
> -  return NULL;
> +  return fallback;
> }
>
> static void vsock_insert_unbound(struct vsock_sock *vsk)
> @@ -304,12 +327,12 @@ void vsock_remove_connected(struct vsock_sock *vsk)
> }
> EXPORT_SYMBOL_GPL(vsock_remove_connected);
>
> -struct sock *vsock_find_bound_socket(struct sockaddr_vm *addr)
> +struct sock *vsock_find_bound_socket(struct sockaddr_vm *addr, struct net 
*net)
> {
>    struct sock *sk;
>
>    spin_lock_bh(&vsock_table_lock);
> -  sk = __vsock_find_bound_socket(addr);
> +  sk = __vsock_find_bound_socket(addr, net);
>    if (sk)
>            sock_hold(sk);
>
> @@ -320,12 +343,13 @@ struct sock *vsock_find_bound_socket(struct sockaddr_vm 
*addr)
> EXPORT_SYMBOL_GPL(vsock_find_bound_socket);
>
> struct sock *vsock_find_connected_socket(struct sockaddr_vm *src,
> -                                   struct sockaddr_vm *dst)
> +                                   struct sockaddr_vm *dst,
> +                                   struct net *net)
> {
>    struct sock *sk;
>
>    spin_lock_bh(&vsock_table_lock);
> -  sk = __vsock_find_connected_socket(src, dst);
> +  sk = __vsock_find_connected_socket(src, dst, net);
>    if (sk)
>            sock_hold(sk);
>
> @@ -644,6 +668,7 @@ static int __vsock_bind_connectible(struct vsock_sock 
*vsk,
> {
>    static u32 port;
>    struct sockaddr_vm new_addr;
> +  struct net *net = sock_net(sk_vsock(vsk));
>
>    if (!port)
>            port = get_random_u32_above(LAST_RESERVED_PORT);
> @@ -660,7 +685,7 @@ static int __vsock_bind_connectible(struct vsock_sock 
*vsk,
>
>                    new_addr.svm_port = port++;
>
> -                  if (!__vsock_find_bound_socket(&new_addr)) {
> +                  if (!__vsock_find_bound_socket(&new_addr, net)) {
>                            found = true;
>                            break;
>                    }
> @@ -677,7 +702,7 @@ static int __vsock_bind_connectible(struct vsock_sock 
*vsk,
>                    return -EACCES;
>            }
>
> -          if (__vsock_find_bound_socket(&new_addr))
> +          if (__vsock_find_bound_socket(&new_addr, net))
>                    return -EADDRINUSE;
>    }
>
> diff --git a/net/vmw_vsock/hyperv_transport.c 
b/net/vmw_vsock/hyperv_transport.c
> index 
31342ab502b4fc35feb812d2c94e0e35ded73771..253609898d24f8a484fcfc3296011c6f501a72a8 
100644
> --- a/net/vmw_vsock/hyperv_transport.c
> +++ b/net/vmw_vsock/hyperv_transport.c
> @@ -313,7 +313,7 @@ static void hvs_open_connection(struct vmbus_channel 
*chan)
>            return;
>
>    hvs_addr_init(&addr, conn_from_host ? if_type : if_instance);
> -  sk = vsock_find_bound_socket(&addr);
> +  sk = vsock_find_bound_socket(&addr, NULL);
>    if (!sk)
>            return;
>
> diff --git a/net/vmw_vsock/virtio_transport_common.c 
b/net/vmw_vsock/virtio_transport_common.c
> index 
7f7de6d8809655fe522749fbbc9025df71f071bd..256d2a4fe482b3cb938a681b6924be69b2065616 
100644
> --- a/net/vmw_vsock/virtio_transport_common.c
> +++ b/net/vmw_vsock/virtio_transport_common.c
> @@ -1590,6 +1590,7 @@ void virtio_transport_recv_pkt(struct virtio_transport 
*t,
>                           struct sk_buff *skb)
> {
>    struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
> +  struct net *net = vsock_global_net();

Why use vsock_global_net() in virtio, but NULL directly in the other
transports?


This was an oversight on my part: I found a bare NULL harder to
reason about, so I switched to the function, but forgot to switch over
the other transports.

BTW, I was unsure whether to just make NULL a macro (e.g.,
VIRTIO_VSOCK_GLOBAL_NET?) instead of a function. I used a function
because A) I noticed in the prior rev that the default net was a
function rather than a macro for &init_net, and B) a function seemed
a little more flexible for future changes. What are your thoughts here?

An inline function in the header should be fine IMHO.

Thanks,
Stefano



Thanks for the review!

Best,
Bobby



Reply via email to