From: Cong Wang <cong.w...@bytedance.com>

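udp_sendmsg() takes lock_sock() internally, so it cannot be used by a
caller that already holds the socket lock. Build ->sendmsg_locked() on
top of it: move the body into __udp_sendmsg(), which takes a new
parameter saying whether the sock is already locked, and make
udp_sendmsg() and udp_sendmsg_locked() thin wrappers around it.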

Cc: John Fastabend <john.fastab...@gmail.com>
Cc: Daniel Borkmann <dan...@iogearbox.net>
Cc: Jakub Sitnicki <ja...@cloudflare.com>
Cc: Lorenz Bauer <l...@cloudflare.com>
Signed-off-by: Cong Wang <cong.w...@bytedance.com>
---
 include/net/udp.h  |  1 +
 net/ipv4/af_inet.c |  1 +
 net/ipv4/udp.c     | 30 +++++++++++++++++++++++-------
 3 files changed, 25 insertions(+), 7 deletions(-)

diff --git a/include/net/udp.h b/include/net/udp.h
index df7cc1edc200..5264ba1439f9 100644
--- a/include/net/udp.h
+++ b/include/net/udp.h
@@ -292,6 +292,7 @@ int udp_get_port(struct sock *sk, unsigned short snum,
 int udp_err(struct sk_buff *, u32);
 int udp_abort(struct sock *sk, int err);
 int udp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len);
+int udp_sendmsg_locked(struct sock *sk, struct msghdr *msg, size_t len);
 int udp_push_pending_frames(struct sock *sk);
 void udp_flush_pending_frames(struct sock *sk);
 int udp_cmsg_send(struct sock *sk, struct msghdr *msg, u16 *gso_size);
diff --git a/net/ipv4/af_inet.c b/net/ipv4/af_inet.c
index a02ce89b56b5..d8c73a848c53 100644
--- a/net/ipv4/af_inet.c
+++ b/net/ipv4/af_inet.c
@@ -1071,6 +1071,7 @@ const struct proto_ops inet_dgram_ops = {
        .setsockopt        = sock_common_setsockopt,
        .getsockopt        = sock_common_getsockopt,
        .sendmsg           = inet_sendmsg,
+       .sendmsg_locked    = udp_sendmsg_locked,
        .recvmsg           = inet_recvmsg,
        .mmap              = sock_no_mmap,
        .sendpage          = inet_sendpage,
diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
index dbd25b59ce0e..93db853601d7 100644
--- a/net/ipv4/udp.c
+++ b/net/ipv4/udp.c
@@ -1024,7 +1024,7 @@ int udp_cmsg_send(struct sock *sk, struct msghdr *msg, u16 *gso_size)
 }
 EXPORT_SYMBOL_GPL(udp_cmsg_send);
 
-int udp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
+static int __udp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len, bool locked)
 {
        struct inet_sock *inet = inet_sk(sk);
        struct udp_sock *up = udp_sk(sk);
@@ -1063,15 +1063,18 @@ int udp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
                 * There are pending frames.
                 * The socket lock must be held while it's corked.
                 */
-               lock_sock(sk);
+               if (!locked)
+                       lock_sock(sk);
                if (likely(up->pending)) {
                        if (unlikely(up->pending != AF_INET)) {
-                               release_sock(sk);
+                               if (!locked)
+                                       release_sock(sk);
                                return -EINVAL;
                        }
                        goto do_append_data;
                }
-               release_sock(sk);
+               if (!locked)
+                       release_sock(sk);
        }
        ulen += sizeof(struct udphdr);
 
@@ -1241,11 +1244,13 @@ int udp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
                goto out;
        }
 
-       lock_sock(sk);
+       if (!locked)
+               lock_sock(sk);
        if (unlikely(up->pending)) {
                /* The socket is already corked while preparing it. */
                /* ... which is an evident application bug. --ANK */
-               release_sock(sk);
+               if (!locked)
+                       release_sock(sk);
 
                net_dbg_ratelimited("socket already corked\n");
                err = -EINVAL;
@@ -1272,7 +1277,8 @@ int udp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
                err = udp_push_pending_frames(sk);
        else if (unlikely(skb_queue_empty(&sk->sk_write_queue)))
                up->pending = 0;
-       release_sock(sk);
+       if (!locked)
+               release_sock(sk);
 
 out:
        ip_rt_put(rt);
@@ -1302,8 +1308,18 @@ int udp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
        err = 0;
        goto out;
 }
+
+int udp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
+{
+       return __udp_sendmsg(sk, msg, len, false);
+}
 EXPORT_SYMBOL(udp_sendmsg);
 
+int udp_sendmsg_locked(struct sock *sk, struct msghdr *msg, size_t len)
+{
+       return __udp_sendmsg(sk, msg, len, true);
+}
+
 int udp_sendpage(struct sock *sk, struct page *page, int offset,
                 size_t size, int flags)
 {
-- 
2.25.1

Reply via email to