translate_ring_addresses and numa_realloc may reallocate the virtio device
and the virtio queue. Callers of those helpers must be extra careful to
refresh any reference they hold to the old data.

Change the prototypes of those functions to hint at this issue: they now
always take indirect pointers (struct virtio_net ** and
struct vhost_virtqueue **) and update them in place.

Besides, when reallocating the device and queue, the code already ensures
that a pointer to a valid device is returned (the original pointer is kept
if a reallocation fails). The checks on the returned pointer can therefore
be removed.

Signed-off-by: David Marchand <david.march...@redhat.com>
---
 lib/vhost/vhost_user.c | 144 +++++++++++++++++++----------------------
 1 file changed, 66 insertions(+), 78 deletions(-)

diff --git a/lib/vhost/vhost_user.c b/lib/vhost/vhost_user.c
index 91d40e32fc..46d4a02c1e 100644
--- a/lib/vhost/vhost_user.c
+++ b/lib/vhost/vhost_user.c
@@ -493,11 +493,11 @@ vhost_user_set_vring_num(struct virtio_net **pdev,
  * make them on the same numa node as the memory of vring descriptor.
  */
 #ifdef RTE_LIBRTE_VHOST_NUMA
-static struct virtio_net*
-numa_realloc(struct virtio_net *dev, int index)
+static void
+numa_realloc(struct virtio_net **pdev, struct vhost_virtqueue **pvq, int index)
 {
        int node, dev_node;
-       struct virtio_net *old_dev;
+       struct virtio_net *dev;
        struct vhost_virtqueue *vq;
        struct batch_copy_elem *bce;
        struct guest_page *gp;
@@ -505,34 +505,35 @@ numa_realloc(struct virtio_net *dev, int index)
        size_t mem_size;
        int ret;
 
-       old_dev = dev;
-       vq = dev->virtqueue[index];
+       dev = *pdev;
+       vq = *pvq;
 
        /*
         * If VQ is ready, it is too late to reallocate, it certainly already
         * happened anyway on VHOST_USER_SET_VRING_ADRR.
         */
        if (vq->ready)
-               return dev;
+               return;
 
        ret = get_mempolicy(&node, NULL, 0, vq->desc, MPOL_F_NODE | 
MPOL_F_ADDR);
        if (ret) {
                VHOST_LOG_CONFIG(dev->ifname, ERR,
                        "unable to get virtqueue %d numa information.\n",
                        index);
-               return dev;
+               return;
        }
 
        if (node == vq->numa_node)
                goto out_dev_realloc;
 
-       vq = rte_realloc_socket(vq, sizeof(*vq), 0, node);
+       vq = rte_realloc_socket(*pvq, sizeof(**pvq), 0, node);
        if (!vq) {
                VHOST_LOG_CONFIG(dev->ifname, ERR,
                        "failed to realloc virtqueue %d on node %d\n",
                        index, node);
-               return dev;
+               return;
        }
+       *pvq = vq;
 
        if (vq != dev->virtqueue[index]) {
                VHOST_LOG_CONFIG(dev->ifname, INFO, "reallocated virtqueue on 
node %d\n", node);
@@ -549,7 +550,7 @@ numa_realloc(struct virtio_net *dev, int index)
                        VHOST_LOG_CONFIG(dev->ifname, ERR,
                                "failed to realloc shadow packed on node %d\n",
                                node);
-                       return dev;
+                       return;
                }
                vq->shadow_used_packed = sup;
        } else {
@@ -561,7 +562,7 @@ numa_realloc(struct virtio_net *dev, int index)
                        VHOST_LOG_CONFIG(dev->ifname, ERR,
                                "failed to realloc shadow split on node %d\n",
                                node);
-                       return dev;
+                       return;
                }
                vq->shadow_used_split = sus;
        }
@@ -572,7 +573,7 @@ numa_realloc(struct virtio_net *dev, int index)
                VHOST_LOG_CONFIG(dev->ifname, ERR,
                        "failed to realloc batch copy elem on node %d\n",
                        node);
-               return dev;
+               return;
        }
        vq->batch_copy_elems = bce;
 
@@ -584,7 +585,7 @@ numa_realloc(struct virtio_net *dev, int index)
                        VHOST_LOG_CONFIG(dev->ifname, ERR,
                                "failed to realloc log cache on node %d\n",
                                node);
-                       return dev;
+                       return;
                }
                vq->log_cache = lc;
        }
@@ -597,7 +598,7 @@ numa_realloc(struct virtio_net *dev, int index)
                        VHOST_LOG_CONFIG(dev->ifname, ERR,
                                "failed to realloc resubmit inflight on node 
%d\n",
                                node);
-                       return dev;
+                       return;
                }
                vq->resubmit_inflight = ri;
 
@@ -610,7 +611,7 @@ numa_realloc(struct virtio_net *dev, int index)
                                VHOST_LOG_CONFIG(dev->ifname, ERR,
                                        "failed to realloc resubmit list on 
node %d\n",
                                        node);
-                               return dev;
+                               return;
                        }
                        ri->resubmit_list = rd;
                }
@@ -621,22 +622,23 @@ numa_realloc(struct virtio_net *dev, int index)
 out_dev_realloc:
 
        if (dev->flags & VIRTIO_DEV_RUNNING)
-               return dev;
+               return;
 
        ret = get_mempolicy(&dev_node, NULL, 0, dev, MPOL_F_NODE | MPOL_F_ADDR);
        if (ret) {
                VHOST_LOG_CONFIG(dev->ifname, ERR, "unable to get numa 
information.\n");
-               return dev;
+               return;
        }
 
        if (dev_node == node)
-               return dev;
+               return;
 
-       dev = rte_realloc_socket(old_dev, sizeof(*dev), 0, node);
+       dev = rte_realloc_socket(*pdev, sizeof(**pdev), 0, node);
        if (!dev) {
-               VHOST_LOG_CONFIG(old_dev->ifname, ERR, "failed to realloc dev 
on node %d\n", node);
-               return old_dev;
+               VHOST_LOG_CONFIG((*pdev)->ifname, ERR, "failed to realloc dev 
on node %d\n", node);
+               return;
        }
+       *pdev = dev;
 
        VHOST_LOG_CONFIG(dev->ifname, INFO, "reallocated device on node %d\n", 
node);
        vhost_devices[dev->vid] = dev;
@@ -648,7 +650,7 @@ numa_realloc(struct virtio_net *dev, int index)
                VHOST_LOG_CONFIG(dev->ifname, ERR,
                        "failed to realloc mem table on node %d\n",
                        node);
-               return dev;
+               return;
        }
        dev->mem = mem;
 
@@ -658,17 +660,17 @@ numa_realloc(struct virtio_net *dev, int index)
                VHOST_LOG_CONFIG(dev->ifname, ERR,
                        "failed to realloc guest pages on node %d\n",
                        node);
-               return dev;
+               return;
        }
        dev->guest_pages = gp;
-
-       return dev;
 }
 #else
-static struct virtio_net*
-numa_realloc(struct virtio_net *dev, int index __rte_unused)
+static void
+numa_realloc(struct virtio_net **pdev, struct vhost_virtqueue **pvq, int index)
 {
-       return dev;
+       RTE_SET_USED(pdev);
+       RTE_SET_USED(pvq);
+       RTE_SET_USED(index);
 }
 #endif
 
@@ -738,88 +740,92 @@ log_addr_to_gpa(struct virtio_net *dev, struct 
vhost_virtqueue *vq)
        return log_gpa;
 }
 
-static struct virtio_net *
-translate_ring_addresses(struct virtio_net *dev, int vq_index)
+static void
+translate_ring_addresses(struct virtio_net **pdev, struct vhost_virtqueue 
**pvq,
+       int vq_index)
 {
-       struct vhost_virtqueue *vq = dev->virtqueue[vq_index];
-       struct vhost_vring_addr *addr = &vq->ring_addrs;
+       struct vhost_virtqueue *vq;
+       struct virtio_net *dev;
        uint64_t len, expected_len;
 
-       if (addr->flags & (1 << VHOST_VRING_F_LOG)) {
+       dev = *pdev;
+       vq = *pvq;
+
+       if (vq->ring_addrs.flags & (1 << VHOST_VRING_F_LOG)) {
                vq->log_guest_addr =
                        log_addr_to_gpa(dev, vq);
                if (vq->log_guest_addr == 0) {
                        VHOST_LOG_CONFIG(dev->ifname, DEBUG, "failed to map 
log_guest_addr.\n");
-                       return dev;
+                       return;
                }
        }
 
        if (vq_is_packed(dev)) {
                len = sizeof(struct vring_packed_desc) * vq->size;
                vq->desc_packed = (struct vring_packed_desc *)(uintptr_t)
-                       ring_addr_to_vva(dev, vq, addr->desc_user_addr, &len);
+                       ring_addr_to_vva(dev, vq, 
vq->ring_addrs.desc_user_addr, &len);
                if (vq->desc_packed == NULL ||
                                len != sizeof(struct vring_packed_desc) *
                                vq->size) {
                        VHOST_LOG_CONFIG(dev->ifname, DEBUG, "failed to map 
desc_packed ring.\n");
-                       return dev;
+                       return;
                }
 
-               dev = numa_realloc(dev, vq_index);
-               vq = dev->virtqueue[vq_index];
-               addr = &vq->ring_addrs;
+               numa_realloc(&dev, &vq, vq_index);
+               *pdev = dev;
+               *pvq = vq;
 
                len = sizeof(struct vring_packed_desc_event);
                vq->driver_event = (struct vring_packed_desc_event *)
                                        (uintptr_t)ring_addr_to_vva(dev,
-                                       vq, addr->avail_user_addr, &len);
+                                       vq, vq->ring_addrs.avail_user_addr, 
&len);
                if (vq->driver_event == NULL ||
                                len != sizeof(struct vring_packed_desc_event)) {
                        VHOST_LOG_CONFIG(dev->ifname, DEBUG,
                                "failed to find driver area address.\n");
-                       return dev;
+                       return;
                }
 
                len = sizeof(struct vring_packed_desc_event);
                vq->device_event = (struct vring_packed_desc_event *)
                                        (uintptr_t)ring_addr_to_vva(dev,
-                                       vq, addr->used_user_addr, &len);
+                                       vq, vq->ring_addrs.used_user_addr, 
&len);
                if (vq->device_event == NULL ||
                                len != sizeof(struct vring_packed_desc_event)) {
                        VHOST_LOG_CONFIG(dev->ifname, DEBUG,
                                "failed to find device area address.\n");
-                       return dev;
+                       return;
                }
 
                vq->access_ok = true;
-               return dev;
+               return;
        }
 
        /* The addresses are converted from QEMU virtual to Vhost virtual. */
        if (vq->desc && vq->avail && vq->used)
-               return dev;
+               return;
 
        len = sizeof(struct vring_desc) * vq->size;
        vq->desc = (struct vring_desc *)(uintptr_t)ring_addr_to_vva(dev,
-                       vq, addr->desc_user_addr, &len);
+                       vq, vq->ring_addrs.desc_user_addr, &len);
        if (vq->desc == 0 || len != sizeof(struct vring_desc) * vq->size) {
                VHOST_LOG_CONFIG(dev->ifname, DEBUG, "failed to map desc 
ring.\n");
-               return dev;
+               return;
        }
 
-       dev = numa_realloc(dev, vq_index);
-       vq = dev->virtqueue[vq_index];
-       addr = &vq->ring_addrs;
+       numa_realloc(&dev, &vq, vq_index);
+       *pdev = dev;
+       *pvq = vq;
 
        len = sizeof(struct vring_avail) + sizeof(uint16_t) * vq->size;
        if (dev->features & (1ULL << VIRTIO_RING_F_EVENT_IDX))
                len += sizeof(uint16_t);
        expected_len = len;
        vq->avail = (struct vring_avail *)(uintptr_t)ring_addr_to_vva(dev,
-                       vq, addr->avail_user_addr, &len);
+                       vq, vq->ring_addrs.avail_user_addr, &len);
        if (vq->avail == 0 || len != expected_len) {
                VHOST_LOG_CONFIG(dev->ifname, DEBUG, "failed to map avail 
ring.\n");
-               return dev;
+               return;
        }
 
        len = sizeof(struct vring_used) +
@@ -828,10 +834,10 @@ translate_ring_addresses(struct virtio_net *dev, int 
vq_index)
                len += sizeof(uint16_t);
        expected_len = len;
        vq->used = (struct vring_used *)(uintptr_t)ring_addr_to_vva(dev,
-                       vq, addr->used_user_addr, &len);
+                       vq, vq->ring_addrs.used_user_addr, &len);
        if (vq->used == 0 || len != expected_len) {
                VHOST_LOG_CONFIG(dev->ifname, DEBUG, "failed to map used 
ring.\n");
-               return dev;
+               return;
        }
 
        if (vq->last_used_idx != vq->used->idx) {
@@ -850,8 +856,6 @@ translate_ring_addresses(struct virtio_net *dev, int 
vq_index)
        VHOST_LOG_CONFIG(dev->ifname, DEBUG, "mapped address avail: %p\n", 
vq->avail);
        VHOST_LOG_CONFIG(dev->ifname, DEBUG, "mapped address used: %p\n", 
vq->used);
        VHOST_LOG_CONFIG(dev->ifname, DEBUG, "log_guest_addr: %" PRIx64 "\n", 
vq->log_guest_addr);
-
-       return dev;
 }
 
 /*
@@ -887,10 +891,7 @@ vhost_user_set_vring_addr(struct virtio_net **pdev,
        if ((vq->enabled && (dev->features &
                                (1ULL << VHOST_USER_F_PROTOCOL_FEATURES))) ||
                        access_ok) {
-               dev = translate_ring_addresses(dev, 
ctx->msg.payload.addr.index);
-               if (!dev)
-                       return RTE_VHOST_MSG_RESULT_ERR;
-
+               translate_ring_addresses(&dev, &vq, 
ctx->msg.payload.addr.index);
                *pdev = dev;
        }
 
@@ -1396,12 +1397,7 @@ vhost_user_set_mem_table(struct virtio_net **pdev,
                         */
                        vring_invalidate(dev, vq);
 
-                       dev = translate_ring_addresses(dev, i);
-                       if (!dev) {
-                               dev = *pdev;
-                               goto free_mem_table;
-                       }
-
+                       translate_ring_addresses(&dev, &vq, i);
                        *pdev = dev;
                }
        }
@@ -2029,17 +2025,9 @@ vhost_user_set_vring_kick(struct virtio_net **pdev,
                file.index, file.fd);
 
        /* Interpret ring addresses only when ring is started. */
-       dev = translate_ring_addresses(dev, file.index);
-       if (!dev) {
-               if (file.fd != VIRTIO_INVALID_EVENTFD)
-                       close(file.fd);
-
-               return RTE_VHOST_MSG_RESULT_ERR;
-       }
-
-       *pdev = dev;
-
        vq = dev->virtqueue[file.index];
+       translate_ring_addresses(&dev, &vq, file.index);
+       *pdev = dev;
 
        /*
         * When VHOST_USER_F_PROTOCOL_FEATURES is not negotiated,
@@ -2595,8 +2583,8 @@ vhost_user_iotlb_msg(struct virtio_net **pdev,
 
                        if (is_vring_iotlb(dev, vq, imsg)) {
                                rte_spinlock_lock(&vq->access_lock);
-                               *pdev = dev = translate_ring_addresses(dev, i);
-                               vq = dev->virtqueue[i];
+                               translate_ring_addresses(&dev, &vq, i);
+                               *pdev = dev;
                                rte_spinlock_unlock(&vq->access_lock);
                        }
                }
-- 
2.36.1

Reply via email to