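The MSG_ZEROCOPY handling in virtio_transport_send_pkt_info() was
logically based on the TCP implementation, so, to make further support
easier, restructure it the way tcp_sendmsg_locked() does it. As part of
this, pass the completion ('uarg') down to virtio_transport_fill_skb(),
which now attaches it to the skb before filling the payload.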

Signed-off-by: Arseniy Krasnov <[email protected]>
---
 net/vmw_vsock/virtio_transport_common.c | 64 ++++++++++++-------------
 1 file changed, 32 insertions(+), 32 deletions(-)

diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c
index 2fd9eaaf5ca6..00caeeaa5590 100644
--- a/net/vmw_vsock/virtio_transport_common.c
+++ b/net/vmw_vsock/virtio_transport_common.c
@@ -73,10 +73,13 @@ static bool virtio_transport_can_zcopy(const struct virtio_transport *t_ops,
 static int virtio_transport_fill_skb(struct sk_buff *skb,
                                     struct virtio_vsock_pkt_info *info,
                                     size_t len,
-                                    bool zcopy)
+                                    bool zcopy, struct ubuf_info *uarg)
 {
        struct msghdr *msg = info->msg;
 
+       /* We have completion - attach it to 'skb'. */
+       skb_zcopy_set(skb, uarg, NULL);
+
        if (zcopy)
                return __zerocopy_sg_from_iter(msg, NULL, skb,
                                               &msg->msg_iter, len, NULL);
@@ -208,7 +211,8 @@ static struct sk_buff *virtio_transport_alloc_skb(struct virtio_vsock_pkt_info *
                                                  u32 src_cid,
                                                  u32 src_port,
                                                  u32 dst_cid,
-                                                 u32 dst_port)
+                                                 u32 dst_port,
+                                                 struct ubuf_info *uarg)
 {
        struct vsock_sock *vsk;
        struct sk_buff *skb;
@@ -245,7 +249,7 @@ static struct sk_buff *virtio_transport_alloc_skb(struct virtio_vsock_pkt_info *
        if (info->msg && payload_len > 0) {
                int err;
 
-               err = virtio_transport_fill_skb(skb, info, payload_len, zcopy);
+               err = virtio_transport_fill_skb(skb, info, payload_len, zcopy, 
uarg);
                if (err)
                        goto out;
 
@@ -321,38 +325,36 @@ static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
        if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW)
                return pkt_len;
 
-       if (info->msg) {
-               /* If zerocopy is not enabled by 'setsockopt()', we behave as
-                * there is no MSG_ZEROCOPY flag set.
+       if (info->msg && (info->msg->msg_flags & MSG_ZEROCOPY)) {
+               /* If 'info->msg' is not NULL, this is only VIRTIO_VSOCK_OP_RW.
+                * 'MSG_ZEROCOPY' flag handling here is based on the same flag
+                * handling from 'tcp_sendmsg_locked()'.
                 */
-               if (!sock_flag(sk_vsock(vsk), SOCK_ZEROCOPY))
-                       info->msg->msg_flags &= ~MSG_ZEROCOPY;
+               if (info->msg->msg_ubuf) {
+                       uarg = info->msg->msg_ubuf;
+                       can_zcopy = virtio_transport_can_zcopy(t_ops, info, 
pkt_len);
+               } else if (sock_flag(sk_vsock(vsk), SOCK_ZEROCOPY)) {
+                       uarg = msg_zerocopy_realloc(sk_vsock(vsk), pkt_len,
+                                                   NULL, false);
+                       if (!uarg) {
+                               virtio_transport_put_credit(vvs, pkt_len);
+                               return -ENOMEM;
+                       }
 
-               if (info->msg->msg_flags & MSG_ZEROCOPY)
                        can_zcopy = virtio_transport_can_zcopy(t_ops, info, 
pkt_len);
 
+                       if (!can_zcopy)
+                               uarg_to_msgzc(uarg)->zerocopy = 0;
+
+                       have_uref = true;
+               }
+
+               /* 'can_zcopy' means that this transmission will be
+                * in zerocopy way (e.g. using 'frags' array).
+                */
                if (can_zcopy)
                        max_skb_len = min_t(u32, VIRTIO_VSOCK_MAX_PKT_BUF_SIZE,
                                            (MAX_SKB_FRAGS * PAGE_SIZE));
-
-               if (info->msg->msg_flags & MSG_ZEROCOPY &&
-                   info->op == VIRTIO_VSOCK_OP_RW) {
-                       uarg = info->msg->msg_ubuf;
-
-                       if (!uarg) {
-                               uarg = msg_zerocopy_realloc(sk_vsock(vsk),
-                                                           pkt_len, NULL, 
false);
-                               if (!uarg) {
-                                       virtio_transport_put_credit(vvs, 
pkt_len);
-                                       return -ENOMEM;
-                               }
-
-                               if (!can_zcopy)
-                                       uarg_to_msgzc(uarg)->zerocopy = 0;
-
-                               have_uref = true;
-                       }
-               }
        }
 
        rest_len = pkt_len;
@@ -365,14 +367,12 @@ static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
 
                skb = virtio_transport_alloc_skb(info, skb_len, can_zcopy,
                                                 src_cid, src_port,
-                                                dst_cid, dst_port);
+                                                dst_cid, dst_port, uarg);
                if (!skb) {
                        ret = -ENOMEM;
                        break;
                }
 
-               skb_zcopy_set(skb, uarg, NULL);
-
                virtio_transport_inc_tx_pkt(vvs, skb);
 
                ret = t_ops->send_pkt(skb, info->net);
@@ -1178,7 +1178,7 @@ static int virtio_transport_reset_no_sock(const struct virtio_transport *t,
                                           le64_to_cpu(hdr->dst_cid),
                                           le32_to_cpu(hdr->dst_port),
                                           le64_to_cpu(hdr->src_cid),
-                                          le32_to_cpu(hdr->src_port));
+                                          le32_to_cpu(hdr->src_port), NULL);
        if (!reply)
                return -ENOMEM;
 
-- 
2.25.1


Reply via email to