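The core vhost-user protocol code should not do socket I/O because the
details are transport-specific.  Move the code that sends and receives
vhost-user messages into trans_af_unix.c.  The core code now sends
replies and slave requests through the new send_reply() and
send_slave_req() transport operations, and vhost_user_msg_handler()
takes an already-received message instead of a file descriptor.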

Signed-off-by: Stefan Hajnoczi <stefa...@redhat.com>
---
 lib/librte_vhost/vhost.h         | 26 ++++++++++++
 lib/librte_vhost/vhost_user.h    |  6 +--
 lib/librte_vhost/trans_af_unix.c | 70 +++++++++++++++++++++++++++++++--
 lib/librte_vhost/vhost_user.c    | 85 +++++++++++-----------------------------
 4 files changed, 115 insertions(+), 72 deletions(-)

diff --git a/lib/librte_vhost/vhost.h b/lib/librte_vhost/vhost.h
index 757e18391..ac9ceefb9 100644
--- a/lib/librte_vhost/vhost.h
+++ b/lib/librte_vhost/vhost.h
@@ -203,6 +203,7 @@ struct guest_page {
 
 struct virtio_net;
 struct vhost_user_socket;
+struct VhostUserMsg;
 
 /**
  * A structure containing function pointers for transport-specific operations.
@@ -264,6 +265,31 @@ struct vhost_transport_ops {
         *  0 on success, -1 on failure
         */
        int (*vring_call)(struct virtio_net *dev, struct vhost_virtqueue *vq);
+
+       /**
+        * Send a reply to the master.
+        *
+        * @param dev
+        *  vhost device
+        * @param reply
+        *  reply message
+        * @return
+        *  0 on success, -1 on failure
+        */
+       int (*send_reply)(struct virtio_net *dev, struct VhostUserMsg *reply);
+
+       /**
+        * Send a slave request to the master.
+        *
+        * @param dev
+        *  vhost device
+        * @param req
+        *  request message
+        * @return
+        *  0 on success, -1 on failure
+        */
+       int (*send_slave_req)(struct virtio_net *dev,
+                             struct VhostUserMsg *req);
 };
 
 /** The traditional AF_UNIX vhost-user protocol transport. */
diff --git a/lib/librte_vhost/vhost_user.h b/lib/librte_vhost/vhost_user.h
index d4bd604b9..dec658dff 100644
--- a/lib/librte_vhost/vhost_user.h
+++ b/lib/librte_vhost/vhost_user.h
@@ -110,11 +110,7 @@ typedef struct VhostUserMsg {
 
 
 /* vhost_user.c */
-int vhost_user_msg_handler(int vid, int fd);
+int vhost_user_msg_handler(int vid, const struct VhostUserMsg *msg);
 int vhost_user_iotlb_miss(struct virtio_net *dev, uint64_t iova, uint8_t perm);
 
-/* socket.c */
-int read_fd_message(int sockfd, char *buf, int buflen, int *fds, int fd_num);
-int send_fd_message(int sockfd, char *buf, int buflen, int *fds, int fd_num);
-
 #endif
diff --git a/lib/librte_vhost/trans_af_unix.c b/lib/librte_vhost/trans_af_unix.c
index dde3e76cd..9e5a5c127 100644
--- a/lib/librte_vhost/trans_af_unix.c
+++ b/lib/librte_vhost/trans_af_unix.c
@@ -75,7 +75,7 @@ static int vhost_user_start_client(struct vhost_user_socket *vsocket);
 static void vhost_user_read_cb(int connfd, void *dat, int *remove);
 
 /* return bytes# of read on success or negative val on failure. */
-int
+static int
 read_fd_message(int sockfd, char *buf, int buflen, int *fds, int fd_num)
 {
        struct iovec iov;
@@ -117,8 +117,8 @@ read_fd_message(int sockfd, char *buf, int buflen, int *fds, int fd_num)
        return ret;
 }
 
-int
-send_fd_message(int sockfd, char *buf, int buflen, int *fds, int fd_num)
+static int
+send_fd_message(int sockfd, void *buf, int buflen, int *fds, int fd_num)
 {
 
        struct iovec iov;
@@ -160,6 +160,23 @@ send_fd_message(int sockfd, char *buf, int buflen, int *fds, int fd_num)
        return ret;
 }
 
+static int
+af_unix_send_reply(struct virtio_net *dev, struct VhostUserMsg *msg)
+{
+       struct vhost_user_connection *conn =
+               container_of(dev, struct vhost_user_connection, device);
+
+       return send_fd_message(conn->connfd, msg,
+                              VHOST_USER_HDR_SIZE + msg->size, NULL, 0);
+}
+
+static int
+af_unix_send_slave_req(struct virtio_net *dev, struct VhostUserMsg *msg)
+{
+       return send_fd_message(dev->slave_req_fd, msg,
+                              VHOST_USER_HDR_SIZE + msg->size, NULL, 0);
+}
+
 static void
 vhost_user_add_connection(int fd, struct vhost_user_socket *vsocket)
 {
@@ -234,6 +251,36 @@ vhost_user_server_new_connection(int fd, void *dat, int *remove __rte_unused)
        vhost_user_add_connection(fd, vsocket);
 }
 
+/* return bytes# of read on success or negative val on failure. */
+static int
+read_vhost_message(int sockfd, struct VhostUserMsg *msg)
+{
+       int ret;
+
+       ret = read_fd_message(sockfd, (char *)msg, VHOST_USER_HDR_SIZE,
+               msg->fds, VHOST_MEMORY_MAX_NREGIONS);
+       if (ret <= 0)
+               return ret;
+
+       if (msg && msg->size) {
+               if (msg->size > sizeof(msg->payload)) {
+                       RTE_LOG(ERR, VHOST_CONFIG,
+                               "invalid msg size: %d\n", msg->size);
+                       return -1;
+               }
+               ret = read(sockfd, &msg->payload, msg->size);
+               if (ret <= 0)
+                       return ret;
+               if (ret != (int)msg->size) {
+                       RTE_LOG(ERR, VHOST_CONFIG,
+                               "read control message failed\n");
+                       return -1;
+               }
+       }
+
+       return ret;
+}
+
 static void
 vhost_user_read_cb(int connfd, void *dat, int *remove)
 {
@@ -241,10 +288,23 @@ vhost_user_read_cb(int connfd, void *dat, int *remove)
        struct vhost_user_socket *vsocket = conn->vsocket;
        struct af_unix_socket *s =
                container_of(vsocket, struct af_unix_socket, socket);
+       struct VhostUserMsg msg;
        int ret;
 
-       ret = vhost_user_msg_handler(conn->device.vid, connfd);
+       ret = read_vhost_message(connfd, &msg);
+       if (ret <= 0) {
+               if (ret < 0)
+                       RTE_LOG(ERR, VHOST_CONFIG,
+                               "vhost read message failed\n");
+               else if (ret == 0)
+                       RTE_LOG(INFO, VHOST_CONFIG,
+                               "vhost peer closed\n");
+               goto err;
+       }
+
+       ret = vhost_user_msg_handler(conn->device.vid, &msg);
        if (ret < 0) {
+err:
                close(connfd);
                *remove = 1;
 
@@ -613,4 +673,6 @@ const struct vhost_transport_ops af_unix_trans_ops = {
        .socket_cleanup = af_unix_socket_cleanup,
        .socket_start = af_unix_socket_start,
        .vring_call = af_unix_vring_call,
+       .send_reply = af_unix_send_reply,
+       .send_slave_req = af_unix_send_slave_req,
 };
diff --git a/lib/librte_vhost/vhost_user.c b/lib/librte_vhost/vhost_user.c
index e54795a41..5f89453bc 100644
--- a/lib/librte_vhost/vhost_user.c
+++ b/lib/librte_vhost/vhost_user.c
@@ -1137,48 +1137,8 @@ vhost_user_iotlb_msg(struct virtio_net **pdev, struct VhostUserMsg *msg)
        return 0;
 }
 
-/* return bytes# of read on success or negative val on failure. */
 static int
-read_vhost_message(int sockfd, struct VhostUserMsg *msg)
-{
-       int ret;
-
-       ret = read_fd_message(sockfd, (char *)msg, VHOST_USER_HDR_SIZE,
-               msg->fds, VHOST_MEMORY_MAX_NREGIONS);
-       if (ret <= 0)
-               return ret;
-
-       if (msg && msg->size) {
-               if (msg->size > sizeof(msg->payload)) {
-                       RTE_LOG(ERR, VHOST_CONFIG,
-                               "invalid msg size: %d\n", msg->size);
-                       return -1;
-               }
-               ret = read(sockfd, &msg->payload, msg->size);
-               if (ret <= 0)
-                       return ret;
-               if (ret != (int)msg->size) {
-                       RTE_LOG(ERR, VHOST_CONFIG,
-                               "read control message failed\n");
-                       return -1;
-               }
-       }
-
-       return ret;
-}
-
-static int
-send_vhost_message(int sockfd, struct VhostUserMsg *msg)
-{
-       if (!msg)
-               return 0;
-
-       return send_fd_message(sockfd, (char *)msg,
-               VHOST_USER_HDR_SIZE + msg->size, NULL, 0);
-}
-
-static int
-send_vhost_reply(int sockfd, struct VhostUserMsg *msg)
+send_vhost_reply(struct virtio_net *dev, struct VhostUserMsg *msg)
 {
        if (!msg)
                return 0;
@@ -1188,7 +1148,16 @@ send_vhost_reply(int sockfd, struct VhostUserMsg *msg)
        msg->flags |= VHOST_USER_VERSION;
        msg->flags |= VHOST_USER_REPLY_MASK;
 
-       return send_vhost_message(sockfd, msg);
+       return dev->trans_ops->send_reply(dev, msg);
+}
+
+static int
+send_vhost_slave_req(struct virtio_net *dev, struct VhostUserMsg *msg)
+{
+       if (!msg)
+               return 0;
+
+       return dev->trans_ops->send_slave_req(dev, msg);
 }
 
 /*
@@ -1230,10 +1199,10 @@ vhost_user_check_and_alloc_queue_pair(struct virtio_net *dev, VhostUserMsg *msg)
 }
 
 int
-vhost_user_msg_handler(int vid, int fd)
+vhost_user_msg_handler(int vid, const struct VhostUserMsg *msg_)
 {
+       struct VhostUserMsg msg = *msg_; /* copy so we can build the reply */
        struct virtio_net *dev;
-       struct VhostUserMsg msg;
        int ret;
 
        dev = get_device(vid);
@@ -1250,18 +1219,8 @@ vhost_user_msg_handler(int vid, int fd)
                }
        }
 
-       ret = read_vhost_message(fd, &msg);
-       if (ret <= 0 || msg.request.master >= VHOST_USER_MAX) {
-               if (ret < 0)
-                       RTE_LOG(ERR, VHOST_CONFIG,
-                               "vhost read message failed\n");
-               else if (ret == 0)
-                       RTE_LOG(INFO, VHOST_CONFIG,
-                               "vhost peer closed\n");
-               else
-                       RTE_LOG(ERR, VHOST_CONFIG,
-                               "vhost read incorrect message\n");
-
+       if (msg.request.master >= VHOST_USER_MAX) {
+               RTE_LOG(ERR, VHOST_CONFIG, "vhost read incorrect message\n");
                return -1;
        }
 
@@ -1284,7 +1243,7 @@ vhost_user_msg_handler(int vid, int fd)
        case VHOST_USER_GET_FEATURES:
                msg.payload.u64 = vhost_user_get_features(dev);
                msg.size = sizeof(msg.payload.u64);
-               send_vhost_reply(fd, &msg);
+               send_vhost_reply(dev, &msg);
                break;
        case VHOST_USER_SET_FEATURES:
                ret = vhost_user_set_features(dev, msg.payload.u64);
@@ -1294,7 +1253,7 @@ vhost_user_msg_handler(int vid, int fd)
 
        case VHOST_USER_GET_PROTOCOL_FEATURES:
                vhost_user_get_protocol_features(dev, &msg);
-               send_vhost_reply(fd, &msg);
+               send_vhost_reply(dev, &msg);
                break;
        case VHOST_USER_SET_PROTOCOL_FEATURES:
                vhost_user_set_protocol_features(dev, msg.payload.u64);
@@ -1316,7 +1275,7 @@ vhost_user_msg_handler(int vid, int fd)
 
                /* it needs a reply */
                msg.size = sizeof(msg.payload.u64);
-               send_vhost_reply(fd, &msg);
+               send_vhost_reply(dev, &msg);
                break;
        case VHOST_USER_SET_LOG_FD:
                close(msg.fds[0]);
@@ -1336,7 +1295,7 @@ vhost_user_msg_handler(int vid, int fd)
        case VHOST_USER_GET_VRING_BASE:
                vhost_user_get_vring_base(dev, &msg);
                msg.size = sizeof(msg.payload.state);
-               send_vhost_reply(fd, &msg);
+               send_vhost_reply(dev, &msg);
                break;
 
        case VHOST_USER_SET_VRING_KICK:
@@ -1355,7 +1314,7 @@ vhost_user_msg_handler(int vid, int fd)
        case VHOST_USER_GET_QUEUE_NUM:
                msg.payload.u64 = VHOST_MAX_QUEUE_PAIRS;
                msg.size = sizeof(msg.payload.u64);
-               send_vhost_reply(fd, &msg);
+               send_vhost_reply(dev, &msg);
                break;
 
        case VHOST_USER_SET_VRING_ENABLE:
@@ -1386,7 +1345,7 @@ vhost_user_msg_handler(int vid, int fd)
        if (msg.flags & VHOST_USER_NEED_REPLY) {
                msg.payload.u64 = !!ret;
                msg.size = sizeof(msg.payload.u64);
-               send_vhost_reply(fd, &msg);
+               send_vhost_reply(dev, &msg);
        }
 
        if (!(dev->flags & VIRTIO_DEV_RUNNING) && virtio_is_ready(dev)) {
@@ -1421,7 +1380,7 @@ vhost_user_iotlb_miss(struct virtio_net *dev, uint64_t iova, uint8_t perm)
                },
        };
 
-       ret = send_vhost_message(dev->slave_req_fd, &msg);
+       ret = send_vhost_slave_req(dev, &msg);
        if (ret < 0) {
                RTE_LOG(ERR, VHOST_CONFIG,
                                "Failed to send IOTLB miss message (%d)\n",
-- 
2.14.3

Reply via email to