Implement splice_read for sockmap using an always-copy approach.
Each page in the psock ingress scatterlist is copied into a newly
allocated page before being added to the pipe, which avoids page
lifetime issues and problems with splicing slab-allocated pages.

Add sk_msg_splice_actor(), which allocates a fresh page via
alloc_page(), copies the data into it with memcpy() and then hands
it to add_to_pipe(). The newly allocated page already has a
refcount of 1, so no additional get_page() is needed. If
add_to_pipe() fails, it releases the buffer itself via
pipe_buf_release(), so the actor needs no explicit cleanup.

Also fix sk_msg_read_core() to update msg_rx->sg.start when the
actor returns 0 part-way through a message. The loop processes
msg_rx->sg entries sequentially; if the actor fails (e.g. the pipe
is full for splice, or the user buffer faults for recvmsg), earlier
entries may already have been consumed and had sge->length set to
0. Without advancing sg.start, a subsequent call would revisit
these zero-length entries and return -EFAULT. This is most likely
with the splice actor, since a default pipe has only 16 slots, but
can in principle hit recvmsg as well.

Signed-off-by: Jiayuan Chen <[email protected]>
---
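For illustration only (not part of this patch): a minimal userspace
sketch of the path this enables. It assumes sock_fd is a connected
TCP socket that a sockmap verdict program redirects to its own
ingress queue, and the helper name drain_to_pipe() and the fd names
are made up for the example:

    #define _GNU_SOURCE
    #include <errno.h>
    #include <fcntl.h>
    #include <unistd.h>

    /*
     * Splice queued ingress data into a pipe. A default pipe has 16
     * slots and the always-copy actor fills one slot per scatterlist
     * entry, so a large backlog may take several calls; the sg.start
     * fix lets each retry resume at the first unconsumed entry
     * instead of failing with -EFAULT.
     */
    static ssize_t drain_to_pipe(int sock_fd, int pipe_wr_fd)
    {
            ssize_t n, total = 0;

            while ((n = splice(sock_fd, NULL, pipe_wr_fd, NULL,
                               65536, SPLICE_F_NONBLOCK)) > 0)
                    total += n;

            if (n < 0 && errno != EAGAIN)
                    return -1;
            return total;
    }
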
 net/core/skmsg.c   | 10 ++++++
 net/ipv4/tcp_bpf.c | 83 ++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 93 insertions(+)

diff --git a/net/core/skmsg.c b/net/core/skmsg.c
index 6a906bfe3aa4..2fcbf8eaf4cf 100644
--- a/net/core/skmsg.c
+++ b/net/core/skmsg.c
@@ -445,6 +445,16 @@ int sk_msg_read_core(struct sock *sk, struct sk_psock *psock,
                                copy = actor(actor_arg, page,
                                             sge->offset, copy);
                        if (!copy) {
+                               /*
+                                * The loop processes msg_rx->sg entries
+                                * sequentially and prior entries may
+                                * already be consumed. Advance sg.start
+                                * so the next call resumes at the correct
+                                * entry, otherwise it would revisit
+                                * zero-length entries and return -EFAULT.
+                                */
+                               if (!peek)
+                                       msg_rx->sg.start = i;
                                copied = copied ? copied : -EFAULT;
                                goto out;
                        }
diff --git a/net/ipv4/tcp_bpf.c b/net/ipv4/tcp_bpf.c
index 606c2b079f86..e85a27e32ea7 100644
--- a/net/ipv4/tcp_bpf.c
+++ b/net/ipv4/tcp_bpf.c
@@ -7,6 +7,7 @@
 #include <linux/init.h>
 #include <linux/wait.h>
 #include <linux/util_macros.h>
+#include <linux/splice.h>
 
 #include <net/inet_common.h>
 #include <net/tls.h>
@@ -444,6 +445,85 @@ static int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
        return ret;
 }
 
+struct tcp_bpf_splice_ctx {
+       struct pipe_inode_info *pipe;
+};
+
+static int sk_msg_splice_actor(void *arg, struct page *page,
+                              unsigned int offset, size_t len)
+{
+       struct tcp_bpf_splice_ctx *ctx = arg;
+       struct pipe_buffer buf = {
+               .ops = &nosteal_pipe_buf_ops,
+       };
+       ssize_t ret;
+
+       buf.page = alloc_page(GFP_KERNEL);
+       if (!buf.page)
+               return 0;
+
+       memcpy(page_address(buf.page), page_address(page) + offset, len);
+       buf.offset = 0;
+       buf.len = len;
+
+       /*
+        * add_to_pipe() calls pipe_buf_release() on failure, which
+        * handles put_page() via nosteal_pipe_buf_ops, so no explicit
+        * cleanup is needed here.
+        */
+       ret = add_to_pipe(ctx->pipe, &buf);
+       if (ret <= 0)
+               return 0;
+       return ret;
+}
+
+static ssize_t tcp_bpf_splice_read(struct socket *sock, loff_t *ppos,
+                                  struct pipe_inode_info *pipe, size_t len,
+                                  unsigned int flags)
+{
+       struct tcp_bpf_splice_ctx ctx = { .pipe = pipe };
+       int bpf_flags = flags & SPLICE_F_NONBLOCK ? MSG_DONTWAIT : 0;
+       struct sock *sk = sock->sk;
+       struct sk_psock *psock;
+       int ret;
+
+       psock = sk_psock_get(sk);
+       if (unlikely(!psock))
+               return tcp_splice_read(sock, ppos, pipe, len, flags);
+       if (!skb_queue_empty(&sk->sk_receive_queue) &&
+           sk_psock_queue_empty(psock)) {
+               sk_psock_put(sk, psock);
+               return tcp_splice_read(sock, ppos, pipe, len, flags);
+       }
+
+       ret = __tcp_bpf_recvmsg(sk, psock, sk_msg_splice_actor, &ctx,
+                               len, bpf_flags);
+       sk_psock_put(sk, psock);
+       if (!ret)
+               return tcp_splice_read(sock, ppos, pipe, len, flags);
+       return ret;
+}
+
+static ssize_t tcp_bpf_splice_read_parser(struct socket *sock, loff_t *ppos,
+                                         struct pipe_inode_info *pipe,
+                                         size_t len, unsigned int flags)
+{
+       struct tcp_bpf_splice_ctx ctx = { .pipe = pipe };
+       int bpf_flags = flags & SPLICE_F_NONBLOCK ? MSG_DONTWAIT : 0;
+       struct sock *sk = sock->sk;
+       struct sk_psock *psock;
+       int ret;
+
+       psock = sk_psock_get(sk);
+       if (unlikely(!psock))
+               return tcp_splice_read(sock, ppos, pipe, len, flags);
+
+       ret = __tcp_bpf_recvmsg_parser(sk, psock, sk_msg_splice_actor, &ctx,
+                                      len, bpf_flags);
+       sk_psock_put(sk, psock);
+       return ret;
+}
+
 static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
                                struct sk_msg *msg, int *copied, int flags)
 {
@@ -671,6 +751,7 @@ static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
        prot[TCP_BPF_BASE].destroy              = sock_map_destroy;
        prot[TCP_BPF_BASE].close                = sock_map_close;
        prot[TCP_BPF_BASE].recvmsg              = tcp_bpf_recvmsg;
+       prot[TCP_BPF_BASE].splice_read          = tcp_bpf_splice_read;
        prot[TCP_BPF_BASE].sock_is_readable     = sk_msg_is_readable;
        prot[TCP_BPF_BASE].ioctl                = tcp_bpf_ioctl;
 
@@ -679,9 +760,11 @@ static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
 
        prot[TCP_BPF_RX]                        = prot[TCP_BPF_BASE];
        prot[TCP_BPF_RX].recvmsg                = tcp_bpf_recvmsg_parser;
+       prot[TCP_BPF_RX].splice_read            = tcp_bpf_splice_read_parser;
 
        prot[TCP_BPF_TXRX]                      = prot[TCP_BPF_TX];
        prot[TCP_BPF_TXRX].recvmsg              = tcp_bpf_recvmsg_parser;
+       prot[TCP_BPF_TXRX].splice_read          = tcp_bpf_splice_read_parser;
 }
 
 static void tcp_bpf_check_v6_needs_rebuild(struct proto *ops)
-- 
2.43.0