The branch main has been updated by markj:

URL: https://cgit.FreeBSD.org/src/commit/?id=caccbaef8e263b1d769e7bcac1c4617bdc12d484

commit caccbaef8e263b1d769e7bcac1c4617bdc12d484
Author:     Mark Johnston <ma...@freebsd.org>
AuthorDate: 2025-02-06 14:16:21 +0000
Commit:     Mark Johnston <ma...@freebsd.org>
CommitDate: 2025-02-06 14:16:21 +0000

    socket: Move SO_SETFIB handling to protocol layers
    
    A FIB number is stored in both struct socket and struct inpcb.  When
    updating it with setsockopt(SO_SETFIB), update both copies atomically.
    This is required to support the new bind_all_fibs mode, since in that
    mode changing the FIB of a bound socket is not permitted (the request
    fails with EISCONN).
    
    This requires a bit more code, but removes a layering violation from
    sosetopt(), which previously hard-coded the list of protocol families
    that implement SO_SETFIB.
    
    Reviewed by:    glebius
    MFC after:      2 weeks
    Sponsored by:   Klara, Inc.
    Sponsored by:   Stormshield
    Differential Revision:  https://reviews.freebsd.org/D48666
---
 sys/kern/uipc_socket.c    | 29 ++++++++++++++---------------
 sys/net/rtsock.c          | 25 +++++++++++++++++++++++++
 sys/netinet/ip_output.c   | 16 ++++++++++++++--
 sys/netinet/raw_ip.c      | 11 ++++-------
 sys/netinet6/ip6_output.c | 16 ++++++++++++++--
 sys/netinet6/raw_ip6.c    | 11 ++++-------
 sys/sys/socketvar.h       |  1 +
 7 files changed, 76 insertions(+), 33 deletions(-)

diff --git a/sys/kern/uipc_socket.c b/sys/kern/uipc_socket.c
index c4ebb43eef18..65cea2e067cf 100644
--- a/sys/kern/uipc_socket.c
+++ b/sys/kern/uipc_socket.c
@@ -3699,6 +3699,19 @@ sorflush(struct socket *so)
 
 }
 
+int
+sosetfib(struct socket *so, int fibnum)
+{
+       if (fibnum < 0 || fibnum >= rt_numfibs)
+               return (EINVAL);
+
+       SOCK_LOCK(so);
+       so->so_fibnum = fibnum;
+       SOCK_UNLOCK(so);
+
+       return (0);
+}
+
 #ifdef SOCKET_HHOOK
 /*
  * Wrapper for Socket established helper hook.
@@ -3847,21 +3860,7 @@ sosetopt(struct socket *so, struct sockopt *sopt)
                        break;
 
                case SO_SETFIB:
-                       error = sooptcopyin(sopt, &optval, sizeof optval,
-                           sizeof optval);
-                       if (error)
-                               goto bad;
-
-                       if (optval < 0 || optval >= rt_numfibs) {
-                               error = EINVAL;
-                               goto bad;
-                       }
-                       if (((so->so_proto->pr_domain->dom_family == PF_INET) ||
-                          (so->so_proto->pr_domain->dom_family == PF_INET6) ||
-                          (so->so_proto->pr_domain->dom_family == PF_ROUTE)))
-                               so->so_fibnum = optval;
-                       else
-                               so->so_fibnum = 0;
+                       error = so->so_proto->pr_ctloutput(so, sopt);
                        break;
 
                case SO_USER_COOKIE:
diff --git a/sys/net/rtsock.c b/sys/net/rtsock.c
index ce5ec9ce22af..f0dcc973ca7c 100644
--- a/sys/net/rtsock.c
+++ b/sys/net/rtsock.c
@@ -423,6 +423,30 @@ rts_attach(struct socket *so, int proto, struct thread *td)
        return (0);
 }
 
+static int
+rts_ctloutput(struct socket *so, struct sockopt *sopt)
+{
+       int error, optval;
+
+       error = ENOPROTOOPT;
+       if (sopt->sopt_dir == SOPT_SET) {
+               switch (sopt->sopt_level) {
+               case SOL_SOCKET:
+                       switch (sopt->sopt_name) {
+                       case SO_SETFIB:
+                               error = sooptcopyin(sopt, &optval,
+                                   sizeof(optval), sizeof(optval));
+                               if (error != 0)
+                                       break;
+                               error = sosetfib(so, optval);
+                               break;
+                       }
+                       break;
+               }
+       }
+       return (error);
+}
+
 static void
 rts_detach(struct socket *so)
 {
@@ -2702,6 +2726,7 @@ static struct protosw routesw = {
        .pr_flags =             PR_ATOMIC|PR_ADDR,
        .pr_abort =             rts_close,
        .pr_attach =            rts_attach,
+       .pr_ctloutput =         rts_ctloutput,
        .pr_detach =            rts_detach,
        .pr_send =              rts_send,
        .pr_shutdown =          rts_shutdown,
diff --git a/sys/netinet/ip_output.c b/sys/netinet/ip_output.c
index d0dbd22512f0..a83400e90c35 100644
--- a/sys/netinet/ip_output.c
+++ b/sys/netinet/ip_output.c
@@ -1094,10 +1094,22 @@ ip_ctloutput(struct socket *so, struct sockopt *sopt)
                    sopt->sopt_dir == SOPT_SET) {
                        switch (sopt->sopt_name) {
                        case SO_SETFIB:
+                               error = sooptcopyin(sopt, &optval,
+                                   sizeof(optval), sizeof(optval));
+                               if (error != 0)
+                                       break;
+
                                INP_WLOCK(inp);
-                               inp->inp_inc.inc_fibnum = so->so_fibnum;
+                               if ((inp->inp_flags & INP_BOUNDFIB) != 0 &&
+                                   optval != so->so_fibnum) {
+                                       INP_WUNLOCK(inp);
+                                       error = EISCONN;
+                                       break;
+                               }
+                               error = sosetfib(inp->inp_socket, optval);
+                               if (error == 0)
+                                       inp->inp_inc.inc_fibnum = optval;
                                INP_WUNLOCK(inp);
-                               error = 0;
                                break;
                        case SO_MAX_PACING_RATE:
 #ifdef RATELIMIT
diff --git a/sys/netinet/raw_ip.c b/sys/netinet/raw_ip.c
index 3a0b9f632fb4..5f6acbaa5e6f 100644
--- a/sys/netinet/raw_ip.c
+++ b/sys/netinet/raw_ip.c
@@ -633,13 +633,10 @@ rip_ctloutput(struct socket *so, struct sockopt *sopt)
        int     error, optval;
 
        if (sopt->sopt_level != IPPROTO_IP) {
-               if ((sopt->sopt_level == SOL_SOCKET) &&
-                   (sopt->sopt_name == SO_SETFIB)) {
-                       INP_WLOCK(inp);
-                       inp->inp_inc.inc_fibnum = so->so_fibnum;
-                       INP_WUNLOCK(inp);
-                       return (0);
-               }
+               if (sopt->sopt_dir == SOPT_SET &&
+                   sopt->sopt_level == SOL_SOCKET &&
+                   sopt->sopt_name == SO_SETFIB)
+                       return (ip_ctloutput(so, sopt));
                return (EINVAL);
        }
 
diff --git a/sys/netinet6/ip6_output.c b/sys/netinet6/ip6_output.c
index c48101aa2990..ed71c58fffbe 100644
--- a/sys/netinet6/ip6_output.c
+++ b/sys/netinet6/ip6_output.c
@@ -1648,10 +1648,22 @@ ip6_ctloutput(struct socket *so, struct sockopt *sopt)
                    sopt->sopt_dir == SOPT_SET) {
                        switch (sopt->sopt_name) {
                        case SO_SETFIB:
+                               error = sooptcopyin(sopt, &optval,
+                                   sizeof(optval), sizeof(optval));
+                               if (error != 0)
+                                       break;
+
                                INP_WLOCK(inp);
-                               inp->inp_inc.inc_fibnum = so->so_fibnum;
+                               if ((inp->inp_flags & INP_BOUNDFIB) != 0 &&
+                                   optval != so->so_fibnum) {
+                                       INP_WUNLOCK(inp);
+                                       error = EISCONN;
+                                       break;
+                               }
+                               error = sosetfib(inp->inp_socket, optval);
+                               if (error == 0)
+                                       inp->inp_inc.inc_fibnum = optval;
                                INP_WUNLOCK(inp);
-                               error = 0;
                                break;
                        case SO_MAX_PACING_RATE:
 #ifdef RATELIMIT
diff --git a/sys/netinet6/raw_ip6.c b/sys/netinet6/raw_ip6.c
index 8f1955164928..b761dc422feb 100644
--- a/sys/netinet6/raw_ip6.c
+++ b/sys/netinet6/raw_ip6.c
@@ -576,13 +576,10 @@ rip6_ctloutput(struct socket *so, struct sockopt *sopt)
                 */
                return (icmp6_ctloutput(so, sopt));
        else if (sopt->sopt_level != IPPROTO_IPV6) {
-               if (sopt->sopt_level == SOL_SOCKET &&
-                   sopt->sopt_name == SO_SETFIB) {
-                       INP_WLOCK(inp);
-                       inp->inp_inc.inc_fibnum = so->so_fibnum;
-                       INP_WUNLOCK(inp);
-                       return (0);
-               }
+               if (sopt->sopt_dir == SOPT_SET &&
+                   sopt->sopt_level == SOL_SOCKET &&
+                   sopt->sopt_name == SO_SETFIB)
+                       return (ip6_ctloutput(so, sopt));
                return (EINVAL);
        }
 
diff --git a/sys/sys/socketvar.h b/sys/sys/socketvar.h
index 6e2eb64ea0b8..bfbb5cbd37fd 100644
--- a/sys/sys/socketvar.h
+++ b/sys/sys/socketvar.h
@@ -552,6 +552,7 @@ int sosend_dgram(struct socket *so, struct sockaddr *addr,
 int    sosend_generic(struct socket *so, struct sockaddr *addr,
            struct uio *uio, struct mbuf *top, struct mbuf *control,
            int flags, struct thread *td);
+int    sosetfib(struct socket *so, int fibnum);
 int    soshutdown(struct socket *so, enum shutdown_how);
 void   soupcall_clear(struct socket *, sb_which);
 void   soupcall_set(struct socket *, sb_which, so_upcall_t, void *);

Reply via email to