From: Jonathan Lemon <b...@fb.com>

If a socket is marked as sending zero-copy (zc) data, have the iterator
retrieve the corresponding zc pages from the netgpu module.  Also make
skb_copy_datagram_iter() reject skbs that carry netgpu zc data.

Signed-off-by: Jonathan Lemon <jonathan.le...@gmail.com>
---
 include/linux/uio.h |  4 ++++
 lib/iov_iter.c      | 53 +++++++++++++++++++++++++++++++++++++++++++++
 net/core/datagram.c |  9 ++++++--
 3 files changed, 64 insertions(+), 2 deletions(-)

diff --git a/include/linux/uio.h b/include/linux/uio.h
index 9576fd8158d7..9d9a68e224b0 100644
--- a/include/linux/uio.h
+++ b/include/linux/uio.h
@@ -227,6 +227,11 @@ ssize_t iov_iter_get_pages(struct iov_iter *i, struct page **pages,
 ssize_t iov_iter_get_pages_alloc(struct iov_iter *i, struct page ***pages,
                        size_t maxsize, size_t *start);
 int iov_iter_npages(const struct iov_iter *i, int maxpages);
+struct sock;
+/* Like iov_iter_get_pages(), but may obtain the pages via netgpu for sk. */
+ssize_t iov_iter_sk_get_pages(struct iov_iter *i, struct page **pages,
+                       size_t maxsize, unsigned int maxpages, size_t *pgoff,
+                       struct sock *sk);
 
 const void *dup_iter(struct iov_iter *new, struct iov_iter *old, gfp_t flags);

diff --git a/lib/iov_iter.c b/lib/iov_iter.c
index bf538c2bec77..69457df64339 100644
--- a/lib/iov_iter.c
+++ b/lib/iov_iter.c
@@ -10,6 +10,9 @@
 #include <linux/scatterlist.h>
 #include <linux/instrumented.h>
 
+#include <net/netgpu.h>
+#include <net/sock.h>
+
 #define PIPE_PARANOIA /* for now */
 
 #define iterate_iovec(i, n, __v, __p, skip, STEP) {    \
@@ -1349,6 +1352,91 @@ ssize_t iov_iter_get_pages(struct iov_iter *i,
 }
 EXPORT_SYMBOL(iov_iter_get_pages);
 
+#if IS_ENABLED(CONFIG_NETGPU)
+/**
+ * iov_iter_sk_get_pages - get the pages backing an iterator, netgpu aware
+ * @i: iterator to take pages from; it is not advanced
+ * @pages: array receiving at most @maxpages page pointers
+ * @maxsize: maximum number of bytes to cover
+ * @maxpages: capacity of @pages
+ * @pgoff: set to the offset of the first byte within pages[0]
+ * @sk: socket the data is sent on, or NULL
+ *
+ * Without a socket, or when @sk has no sk_user_data, this is
+ * iov_iter_get_pages().  Otherwise the pages are looked up with
+ * netgpu_get_pages(); only user-backed iovec iterators used as a data
+ * source (WRITE) are accepted, anything else fails with -EFAULT.  As with
+ * iov_iter_get_pages(), only the first non-empty segment is mapped.
+ *
+ * Return: number of bytes covered by @pages, 0 if nothing was mapped, or
+ * a negative errno.
+ */
+ssize_t iov_iter_sk_get_pages(struct iov_iter *i, struct page **pages,
+               size_t maxsize, unsigned int maxpages, size_t *pgoff,
+               struct sock *sk)
+{
+       const struct iovec *iov;
+       unsigned long addr;
+       struct iovec v;
+       unsigned int n;
+       size_t skip;
+       size_t len;
+       int ret;
+
+       /*
+        * sk may be NULL: zerocopy_sg_from_iter() calls
+        * __zerocopy_sg_from_iter() without a socket.
+        * NOTE(review): sk_user_data is a general-purpose struct sock
+        * field; confirm only netgpu sockets can reach here with it set.
+        */
+       if (!sk || !sk->sk_user_data)
+               return iov_iter_get_pages(i, pages, maxsize, maxpages, pgoff);
+
+       if (maxsize > i->count)
+               maxsize = i->count;
+
+       /* Only user iovecs that are a data source (WRITE) are handled. */
+       if (!iter_is_iovec(i))
+               return -EFAULT;
+
+       if (iov_iter_rw(i) != WRITE)
+               return -EFAULT;
+
+       /* iterate_iovec() modifies skip; don't let it touch i->iov_offset */
+       skip = i->iov_offset;
+       iterate_iovec(i, maxsize, v, iov, skip, ({
+               addr = (unsigned long)v.iov_base;
+               *pgoff = addr & (PAGE_SIZE - 1);
+               len = v.iov_len + *pgoff;
+
+               if (len > maxpages * PAGE_SIZE)
+                       len = maxpages * PAGE_SIZE;
+
+               n = DIV_ROUND_UP(len, PAGE_SIZE);
+
+               /*
+                * NOTE(review): unlike iov_iter_get_pages(), addr is not
+                * rounded down to a page boundary; confirm that
+                * netgpu_get_pages() expects the unaligned address.
+                */
+               ret = netgpu_get_pages(sk, pages, addr, n);
+               if (ret <= 0)
+                       return ret;
+               /* a short result only covers the pages actually returned */
+               return (ret == n ? len : ret * PAGE_SIZE) - *pgoff;
+       0;}));
+       return 0;
+}
+#else
+/* Without CONFIG_NETGPU the socket is ignored. */
+ssize_t iov_iter_sk_get_pages(struct iov_iter *i, struct page **pages,
+               size_t maxsize, unsigned int maxpages, size_t *pgoff,
+               struct sock *sk)
+{
+       return iov_iter_get_pages(i, pages, maxsize, maxpages, pgoff);
+}
+#endif
+
 static struct page **get_pages_array(size_t n)
 {
        return kvmalloc_array(n, sizeof(struct page *), GFP_KERNEL);
diff --git a/net/core/datagram.c b/net/core/datagram.c
index 639745d4f3b9..d91f14dc56be 100644
--- a/net/core/datagram.c
+++ b/net/core/datagram.c
@@ -530,6 +530,11 @@ int skb_copy_datagram_iter(const struct sk_buff *skb, int offset,
                           struct iov_iter *to, int len)
 {
        trace_skb_copy_datagram_iovec(skb, len);
+       /* netgpu zc skbs must not be copied out through this generic path */
+       if (unlikely(skb->zc_netgpu)) {
+               pr_err_ratelimited("skb netgpu datagram on !netgpu sk\n");
+               return -EFAULT;
+       }
        return __skb_datagram_iter(skb, offset, to, len, false,
                        simple_copy_to_iter, NULL);
 }
@@ -631,8 +636,9 @@ int __zerocopy_sg_from_iter(struct sock *sk, struct sk_buff *skb,
                if (frag == MAX_SKB_FRAGS)
                        return -EMSGSIZE;
 
-               copied = iov_iter_get_pages(from, pages, length,
-                                           MAX_SKB_FRAGS - frag, &start);
+               copied = iov_iter_sk_get_pages(from, pages, length,
+                                              MAX_SKB_FRAGS - frag, &start,
+                                              sk);
                if (copied < 0)
                        return -EFAULT;

-- 
2.24.1

Reply via email to