From: Cong Wang <cong.w...@bytedance.com>

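Similarly, udpv6_sendmsg() also takes lock_sock() internally, so we
have to build ->sendmsg_locked() on top of it: its body moves into
__udpv6_sendmsg(), whose new @locked argument tells it that the
caller already holds the socket lock.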

For ->read_sock(), we can simply reuse udp_read_sock(), which this
patch exports for use by net/ipv6/af_inet6.c.

Cc: John Fastabend <john.fastab...@gmail.com>
Cc: Daniel Borkmann <dan...@iogearbox.net>
Cc: Jakub Sitnicki <ja...@cloudflare.com>
Cc: Lorenz Bauer <l...@cloudflare.com>
Signed-off-by: Cong Wang <cong.w...@bytedance.com>
---
 include/net/ipv6.h  |  1 +
 net/ipv4/udp.c      |  1 +
 net/ipv6/af_inet6.c |  2 ++
 net/ipv6/udp.c      | 27 +++++++++++++++++++++------
 4 files changed, 25 insertions(+), 6 deletions(-)

diff --git a/include/net/ipv6.h b/include/net/ipv6.h
index bd1f396cc9c7..48b6850dae85 100644
--- a/include/net/ipv6.h
+++ b/include/net/ipv6.h
@@ -1119,6 +1119,7 @@ int inet6_hash_connect(struct inet_timewait_death_row 
*death_row,
 int inet6_sendmsg(struct socket *sock, struct msghdr *msg, size_t size);
 int inet6_recvmsg(struct socket *sock, struct msghdr *msg, size_t size,
                  int flags);
+int udpv6_sendmsg_locked(struct sock *sk, struct msghdr *msg, size_t len);
 
 /*
  * reassembly.c
diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
index fd8f27ee5b4e..6658db231475 100644
--- a/net/ipv4/udp.c
+++ b/net/ipv4/udp.c
@@ -1831,6 +1831,7 @@ int udp_read_sock(struct sock *sk, read_descriptor_t 
*desc,
 
        return copied;
 }
+EXPORT_SYMBOL(udp_read_sock);
 
 /*
  *     This should be easy, if there is something there we
diff --git a/net/ipv6/af_inet6.c b/net/ipv6/af_inet6.c
index 1fb75f01756c..634ab3a825d7 100644
--- a/net/ipv6/af_inet6.c
+++ b/net/ipv6/af_inet6.c
@@ -714,7 +714,9 @@ const struct proto_ops inet6_dgram_ops = {
        .setsockopt        = sock_common_setsockopt,    /* ok           */
        .getsockopt        = sock_common_getsockopt,    /* ok           */
        .sendmsg           = inet6_sendmsg,             /* retpoline's sake */
+       .sendmsg_locked    = udpv6_sendmsg_locked,      /* caller holds sk lock */
        .recvmsg           = inet6_recvmsg,             /* retpoline's sake */
+       .read_sock         = udp_read_sock,             /* from net/ipv4/udp.c */
        .mmap              = sock_no_mmap,
        .sendpage          = sock_no_sendpage,
        .set_peek_off      = sk_set_peek_off,
diff --git a/net/ipv6/udp.c b/net/ipv6/udp.c
index ef2c75bb4771..124a316da410 100644
--- a/net/ipv6/udp.c
+++ b/net/ipv6/udp.c
@@ -1272,7 +1272,10 @@ static int udp_v6_push_pending_frames(struct sock *sk)
        return err;
 }
 
-int udpv6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
+/* @locked: true if the caller already holds lock_sock(sk), in which case
+ * the socket lock is neither taken nor released here.
+ */
+static int __udpv6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len, bool locked)
 {
        struct ipv6_txoptions opt_space;
        struct udp_sock *up = udp_sk(sk);
@@ -1361,7 +1364,9 @@ int udpv6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
                 * There are pending frames.
                 * The socket lock must be held while it's corked.
                 */
-               lock_sock(sk);
+               if (!locked)
+                       lock_sock(sk);
                if (likely(up->pending)) {
                        if (unlikely(up->pending != AF_INET6)) {
-                               release_sock(sk);
+                               if (!locked)
+                                       release_sock(sk);
@@ -1370,7 +1375,8 @@ int udpv6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
                        dst = NULL;
                        goto do_append_data;
                }
-               release_sock(sk);
+               if (!locked)
+                       release_sock(sk);
        }
        ulen += sizeof(struct udphdr);
 
@@ -1533,11 +1539,13 @@ int udpv6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
                goto out;
        }
 
-       lock_sock(sk);
+       if (!locked)
+               lock_sock(sk);
        if (unlikely(up->pending)) {
                /* The socket is already corked while preparing it. */
                /* ... which is an evident application bug. --ANK */
-               release_sock(sk);
+               if (!locked)
+                       release_sock(sk);
 
                net_dbg_ratelimited("udp cork app bug 2\n");
                err = -EINVAL;
@@ -1562,7 +1570,8 @@ int udpv6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
 
        if (err > 0)
                err = np->recverr ? net_xmit_errno(err) : 0;
-       release_sock(sk);
+       if (!locked)
+               release_sock(sk);
 
 out:
        dst_release(dst);
@@ -1593,6 +1602,17 @@ int udpv6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
        goto out;
 }
 
+int udpv6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
+{
+       return __udpv6_sendmsg(sk, msg, len, false);
+}
+
+/* Variant of udpv6_sendmsg() for callers that already hold lock_sock(sk). */
+int udpv6_sendmsg_locked(struct sock *sk, struct msghdr *msg, size_t len)
+{
+       return __udpv6_sendmsg(sk, msg, len, true);
+}
+
 void udpv6_destroy_sock(struct sock *sk)
 {
        struct udp_sock *up = udp_sk(sk);
-- 
2.25.1

Reply via email to