On Thu, Jan 9, 2014 at 4:31 PM, Michael S. Tsirkin <m...@redhat.com> wrote: > On Thu, Jan 09, 2014 at 03:59:58PM +0100, Antonios Motakis wrote: >> Add structures for passing vhost-user messages over a unix domain socket. >> This is the equivalent of the existing vhost-kernel ioctls. >> >> Connect to the named unix domain socket. The system call sendmsg >> is used for communication. To be able to pass file descriptors >> between processes - we use SCM_RIGHTS type in the message control header. >> >> Signed-off-by: Antonios Motakis <a.mota...@virtualopensystems.com> >> Signed-off-by: Nikolay Nikolaev <n.nikol...@virtualopensystems.com> >> --- >> hw/virtio/vhost-backend.c | 268 >> +++++++++++++++++++++++++++++++++++++++++++++- >> 1 file changed, 263 insertions(+), 5 deletions(-) >> >> diff --git a/hw/virtio/vhost-backend.c b/hw/virtio/vhost-backend.c >> index 1d83e1d..b33d35f 100644 >> --- a/hw/virtio/vhost-backend.c >> +++ b/hw/virtio/vhost-backend.c >> @@ -11,34 +11,292 @@ >> #include "hw/virtio/vhost.h" >> #include "hw/virtio/vhost-backend.h" >> #include "qemu/error-report.h" >> +#include "qemu/sockets.h" >> >> #include <fcntl.h> >> #include <unistd.h> >> #include <sys/ioctl.h> >> +#include <sys/socket.h> >> +#include <sys/un.h> >> +#include <linux/vhost.h> >> + >> +#define VHOST_MEMORY_MAX_NREGIONS 8 >> +#define VHOST_USER_SOCKTO (300) /* msec */ >> + >> +typedef enum VhostUserRequest { >> + VHOST_USER_NONE = 0, >> + VHOST_USER_GET_FEATURES = 1, >> + VHOST_USER_SET_FEATURES = 2, >> + VHOST_USER_SET_OWNER = 3, >> + VHOST_USER_RESET_OWNER = 4, >> + VHOST_USER_SET_MEM_TABLE = 5, >> + VHOST_USER_SET_LOG_BASE = 6, >> + VHOST_USER_SET_LOG_FD = 7, >> + VHOST_USER_SET_VRING_NUM = 8, >> + VHOST_USER_SET_VRING_ADDR = 9, >> + VHOST_USER_SET_VRING_BASE = 10, >> + VHOST_USER_GET_VRING_BASE = 11, >> + VHOST_USER_SET_VRING_KICK = 12, >> + VHOST_USER_SET_VRING_CALL = 13, >> + VHOST_USER_SET_VRING_ERR = 14, >> + VHOST_USER_NET_SET_BACKEND = 15, >> + VHOST_USER_ECHO = 16, >> + 
VHOST_USER_MAX >> +} VhostUserRequest; >> + >> +typedef struct VhostUserMemoryRegion { >> + __u64 guest_phys_addr; >> + __u64 memory_size; >> + __u64 userspace_addr; > > Don't use __uXX like this. > Use posix uintXX_t types. > >> +} VhostUserMemoryRegion; >> + >> +typedef struct VhostUserMemory { >> + __u32 nregions; >> + __u32 padding; >> + VhostUserMemoryRegion regions[VHOST_MEMORY_MAX_NREGIONS]; >> +} VhostUserMemory; >> + >> +typedef struct VhostUserMsg { >> + VhostUserRequest request; >> + >> +#define VHOST_USER_VERSION_MASK (0x3) >> +#define VHOST_USER_REPLY_MASK (0x1<<2) >> + __u32 flags; >> + __u32 size; /* the following payload size */ >> + union { >> + __u64 u64; >> + int fd; >> + struct vhost_vring_state state; >> + struct vhost_vring_addr addr; >> + struct vhost_vring_file file; > > Hmm I don't get it. fd is passes in anciliarry data, right? > what are the fd #s doing in data portion? > >> + VhostUserMemory memory; >> + }; >> +} QEMU_PACKED VhostUserMsg; >> + >> +#define MEMB_SIZE(t, m) (sizeof(((t *)0)->m)) > > That's unnecessarily tricky. > Just do > > static VhostUserMsg m; > > And then you can use sizeof(m.request) without issues and macros. > >> +#define VHOST_USER_HDR_SIZE (MEMB_SIZE(VhostUserMsg, request) \ >> + + MEMB_SIZE(VhostUserMsg, flags) \ >> + + MEMB_SIZE(VhostUserMsg, size)) > > or just split header to a separate structure and use sizeof on that. > >> + >> +/* The version of the protocol we support */ >> +#define VHOST_USER_VERSION (0x1) >> + >> +static int vhost_user_cleanup(struct vhost_dev *dev); > > dont forward declare static functions. > order them logically instead. > >> + >> +static int vhost_user_recv(int fd, VhostUserMsg *msg) >> +{ >> + ssize_t r; >> + __u8 *p = (__u8 *)msg; >> + >> + /* read the header */ >> + do { >> + r = read(fd, p, VHOST_USER_HDR_SIZE); >> + } while (r < 0 && errno == EINTR); > > What if you get a partial buffer? > You use SOCK_STREAM so I think it can happen right? 
> > >> + >> + if (r > 0) { >> + p += VHOST_USER_HDR_SIZE; >> + /* read the payload */ >> + do { >> + r = read(fd, p, msg->size); >> + } while (r < 0 && errno == EINTR); > > That's a security bug. > What if msg->size is bigger than sizeof *msg? > you will corrupt memory. > You must validate size. > >> + } >> + >> + if (r < 0) { >> + error_report("Failed to read msg, reason: %s\n", strerror(errno)); >> + } >> + >> + return (r < 0) ? -1 : 0; >> +} >> + >> +static int vhost_user_send_fds(int fd, const VhostUserMsg *msg, int *fds, >> + size_t fd_num) >> +{ >> + int r; >> + >> + struct msghdr msgh; >> + struct iovec iov[1]; > > same as struct iovec iov only longer and more confusing? >> + >> + size_t fd_size = fd_num * sizeof(int); >> + char control[CMSG_SPACE(fd_size)]; >> + struct cmsghdr *cmsg; >> + >> + memset(&msgh, 0, sizeof(msgh)); >> + memset(control, 0, sizeof(control)); >> + >> + /* set the payload */ >> + iov[0].iov_base = (void *)msg; > > it's void * already, no need for cast > >> + iov[0].iov_len = VHOST_USER_HDR_SIZE + msg->size; >> + >> + msgh.msg_iov = iov; >> + msgh.msg_iovlen = 1; >> + >> + if (fd_num) { >> + msgh.msg_control = control; >> + msgh.msg_controllen = sizeof(control); >> + >> + cmsg = CMSG_FIRSTHDR(&msgh); >> + >> + cmsg->cmsg_len = CMSG_LEN(fd_size); >> + cmsg->cmsg_level = SOL_SOCKET; >> + cmsg->cmsg_type = SCM_RIGHTS; >> + memcpy(CMSG_DATA(cmsg), fds, fd_size); >> + } else { >> + msgh.msg_control = 0; >> + msgh.msg_controllen = 0; >> + } >> + >> + do { >> + r = sendmsg(fd, &msgh, 0); >> + } while (r < 0 && errno == EINTR); > > again I think this can write message partially for SOCK_STREAM. 
> >> + >> + if (r < 0) { >> + error_report("Failed to send msg(%d), reason: %s\n", >> + msg->request, strerror(errno)); >> + } else { >> + r = 0; >> + } >> + >> + return r; >> +} >> + >> +static int vhost_user_echo(struct vhost_dev *dev) >> +{ >> + VhostUserMsg msg; >> + int fd = dev->control; >> + >> + if (fd < 0) { >> + return -1; >> + } >> + >> + /* check connection */ >> + msg.request = VHOST_USER_ECHO; >> + msg.flags = VHOST_USER_VERSION; >> + msg.size = 0; >> + >> + if (vhost_user_send_fds(fd, &msg, 0, 0) < 0) { >> + error_report("ECHO failed\n"); >> + return -1; >> + } >> + >> + if (vhost_user_recv(fd, &msg) < 0) { >> + error_report("ECHO failed\n"); >> + return -1; >> + } >> + >> + if ((msg.flags & VHOST_USER_REPLY_MASK) == 0 || >> + (msg.flags & VHOST_USER_VERSION_MASK) != VHOST_USER_VERSION) { >> + error_report("ECHO failed\n"); >> + return -1; >> + } >> + >> + return 0; >> +} >> >> static int vhost_user_call(struct vhost_dev *dev, unsigned long int request, >> void *arg) >> { >> + int fd = dev->control; >> + VhostUserMsg msg; >> + int result = 0; >> + int fds[VHOST_MEMORY_MAX_NREGIONS]; >> + size_t fd_num = 0; >> + >> assert(dev->vhost_ops->backend_type == VHOST_BACKEND_TYPE_USER); >> - error_report("vhost_user_call not implemented\n"); >> >> - return -1; >> + if (fd < 0) { >> + return -1; >> + } >> + >> + msg.request = VHOST_USER_NONE; >> + msg.flags = VHOST_USER_VERSION; >> + msg.size = 0; >> + >> + switch (request) { >> + default: >> + error_report("vhost-user trying to send unhandled ioctl\n"); >> + return -1; >> + break; >> + } >> + >> + result = vhost_user_send_fds(fd, &msg, fds, fd_num); >> + >> + return result; >> } >> >> static int vhost_user_init(struct vhost_dev *dev, const char *devpath) >> { >> + int fd = -1; >> + struct sockaddr_un un; >> + struct timeval tv; >> + size_t len; >> + >> assert(dev->vhost_ops->backend_type == VHOST_BACKEND_TYPE_USER); >> - error_report("vhost_user_init not implemented\n"); >> >> + /* Create the socket */ >> + 
fd = qemu_socket(AF_UNIX, SOCK_STREAM, 0); >> + if (fd == -1) { >> + error_report("socket: %s", strerror(errno)); >> + return -1; >> + } >> + >> + /* Set socket options */ >> + tv.tv_sec = VHOST_USER_SOCKTO / 1000; >> + tv.tv_usec = (VHOST_USER_SOCKTO % 1000) * 1000; >> + >> + if (setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(struct timeval)) >> + == -1) { >> + error_report("setsockopt SO_SNDTIMEO: %s", strerror(errno)); >> + goto fail; >> + } >> + >> + if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(struct timeval)) >> + == -1) { >> + error_report("setsockopt SO_RCVTIMEO: %s", strerror(errno)); >> + goto fail; >> + } >> + >> + /* Connect */ >> + un.sun_family = AF_UNIX; >> + strcpy(un.sun_path, devpath); >> + >> + len = sizeof(un.sun_family) + strlen(devpath); >> + >> + if (connect(fd, (struct sockaddr *) &un, len) == -1) { >> + error_report("connect: %s", strerror(errno)); >> + goto fail; >> + } >> + >> + /* Cleanup if there is previous connection left */ >> + if (dev->control >= 0) { >> + vhost_user_cleanup(dev); >> + } >> + dev->control = fd; >> + >> + if (vhost_user_echo(dev) < 0) { >> + dev->control = -1; >> + goto fail; >> + } >> + >> + return fd; >> + >> +fail: >> + close(fd); >> return -1; >> + >> } >> >> static int vhost_user_cleanup(struct vhost_dev *dev) >> { >> + int r = 0; >> + >> assert(dev->vhost_ops->backend_type == VHOST_BACKEND_TYPE_USER); >> - error_report("vhost_user_cleanup not implemented\n"); >> >> - return -1; >> + if (dev->control >= 0) { >> + r = close(dev->control); >> + } >> + dev->control = -1; >> + >> + return r; >> } >> >> static const VhostOps user_ops = { >> -- >> 1.8.3.2 >>
Ack, thanks for your feedback. Best regards, Antonios