On 11.10.2018 12:24, Maxime Coquelin wrote:
> As soon as some ancillary data (fds) are received, they are copied
> without checking their length.
> 
> This patch adds the number of received fds to the message;
> it is set in read_vhost_message().
> 
> This is preliminary work to support sending fds to QEMU.
> 
> Signed-off-by: Dr. David Alan Gilbert <dgilb...@redhat.com>
> Signed-off-by: Maxime Coquelin <maxime.coque...@redhat.com>
> ---
>  lib/librte_vhost/socket.c     | 25 ++++++++++++++++++++-----
>  lib/librte_vhost/vhost_user.c |  2 +-
>  lib/librte_vhost/vhost_user.h |  4 +++-
>  3 files changed, 24 insertions(+), 7 deletions(-)
> 
> diff --git a/lib/librte_vhost/socket.c b/lib/librte_vhost/socket.c
> index d63031747..3b0287a26 100644
> --- a/lib/librte_vhost/socket.c
> +++ b/lib/librte_vhost/socket.c
> @@ -94,18 +94,24 @@ static struct vhost_user vhost_user = {
>       .mutex = PTHREAD_MUTEX_INITIALIZER,
>  };
>  
> -/* return bytes# of read on success or negative val on failure. */
> +/*
> + * Return the number of bytes read on success or a negative value on failure.
> + * Update fd_num with the number of fds read.
> + */
>  int
> -read_fd_message(int sockfd, char *buf, int buflen, int *fds, int fd_num)
> +read_fd_message(int sockfd, char *buf, int buflen, int *fds, int max_fds,
> +             int *fd_num)
>  {
>       struct iovec iov;
>       struct msghdr msgh;
> -     size_t fdsize = fd_num * sizeof(int);
> -     char control[CMSG_SPACE(fdsize)];
> +     char control[CMSG_SPACE(max_fds * sizeof(int))];
>       struct cmsghdr *cmsg;
>       int got_fds = 0;
> +     int *tmp_fds;
>       int ret;
>  
> +     *fd_num = 0;
> +
>       memset(&msgh, 0, sizeof(msgh));
>       iov.iov_base = buf;
>       iov.iov_len  = buflen;
> @@ -131,13 +137,22 @@ read_fd_message(int sockfd, char *buf, int buflen, int *fds, int fd_num)
>               if ((cmsg->cmsg_level == SOL_SOCKET) &&
>                       (cmsg->cmsg_type == SCM_RIGHTS)) {
>                       got_fds = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int);
> +                     if (got_fds > max_fds) {

Hmm. I just noticed that 'msg_controllen' is sized so that no more than
max_fds descriptors can be received. So this case should not be
possible: recvmsg() will set MSG_CTRUNC in msg_flags, and we will return
before the loop.

> +                             RTE_LOG(ERR, VHOST_CONFIG,
> +                                     "Received msg contains more fds than supported\n");
> +                             tmp_fds = (int *)CMSG_DATA(cmsg);
> +                             while (got_fds--)
> +                                     close(tmp_fds[got_fds]);
> +                             return -1;
> +                     }
> +                     *fd_num = got_fds;
>                       memcpy(fds, CMSG_DATA(cmsg), got_fds * sizeof(int));
>                       break;
>               }
>       }
>  
>       /* Clear out unused file descriptors */
> -     while (got_fds < fd_num)
> +     while (got_fds < max_fds)
>               fds[got_fds++] = -1;
>  
>       return ret;
> diff --git a/lib/librte_vhost/vhost_user.c b/lib/librte_vhost/vhost_user.c
> index 83d3e6321..c1c5f35ff 100644
> --- a/lib/librte_vhost/vhost_user.c
> +++ b/lib/librte_vhost/vhost_user.c
> @@ -1509,7 +1509,7 @@ read_vhost_message(int sockfd, struct VhostUserMsg *msg)
>       int ret;
>  
>       ret = read_fd_message(sockfd, (char *)msg, VHOST_USER_HDR_SIZE,
> -             msg->fds, VHOST_MEMORY_MAX_NREGIONS);
> +             msg->fds, VHOST_MEMORY_MAX_NREGIONS, &msg->fd_num);
>       if (ret <= 0)
>               return ret;
>  
> diff --git a/lib/librte_vhost/vhost_user.h b/lib/librte_vhost/vhost_user.h
> index 62654f736..9a91d496b 100644
> --- a/lib/librte_vhost/vhost_user.h
> +++ b/lib/librte_vhost/vhost_user.h
> @@ -132,6 +132,7 @@ typedef struct VhostUserMsg {
>               VhostUserVringArea area;
>       } payload;
>       int fds[VHOST_MEMORY_MAX_NREGIONS];
> +     int fd_num;
>  } __attribute((packed)) VhostUserMsg;
>  
>  #define VHOST_USER_HDR_SIZE offsetof(VhostUserMsg, payload.u64)
> @@ -155,7 +156,8 @@ int vhost_user_iotlb_miss(struct virtio_net *dev, uint64_t iova, uint8_t perm);
>  int vhost_user_host_notifier_ctrl(int vid, bool enable);
>  
>  /* socket.c */
> -int read_fd_message(int sockfd, char *buf, int buflen, int *fds, int fd_num);
> +int read_fd_message(int sockfd, char *buf, int buflen, int *fds, int max_fds,
> +             int *fd_num);
>  int send_fd_message(int sockfd, char *buf, int buflen, int *fds, int fd_num);
>  
>  #endif
> 

Reply via email to