The common vhost code only supported a single mmap per device. vhost-user
worked around this by saving the address/length/fd of each mmap after the end
of struct virtio_memory. This only works if the vhost-user code frees
dev->mem, because the common code is unaware of the extra data. The
VHOST_USER_RESET_OWNER message is one situation where the common code frees
dev->mem and leaks the fds and mappings. This leak happens every time a VM is
shut down.

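The new code closes the fds once the regions have been mmap'd, since munmap
does not need them. It records each mapping's address and length in a new
mappings array in struct virtio_memory, which the common code walks in
rte_vhost_free_mem() to unmap everything before freeing dev->mem.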

The vhost-cuse changes are only compile tested.

Signed-off-by: Rich Lane <rlane at bigswitch.com>
---
 lib/librte_vhost/rte_virtio_net.h             | 14 +++--
 lib/librte_vhost/vhost_cuse/virtio-net-cdev.c | 24 ++++---
 lib/librte_vhost/vhost_user/virtio-net-user.c | 91 ++++++++-------------------
 lib/librte_vhost/virtio-net.c                 | 24 ++++++-
 lib/librte_vhost/virtio-net.h                 |  3 +
 5 files changed, 75 insertions(+), 81 deletions(-)

diff --git a/lib/librte_vhost/rte_virtio_net.h b/lib/librte_vhost/rte_virtio_net.h
index 10dcb90..5233879 100644
--- a/lib/librte_vhost/rte_virtio_net.h
+++ b/lib/librte_vhost/rte_virtio_net.h
@@ -144,16 +144,22 @@ struct virtio_memory_regions {
        uint64_t        address_offset;         /**< Offset of region for address translation. */
 };

+/**
+ * Address and length of one mmap()'d range, recorded so it can be munmap'd later.
+ */
+struct virtio_memory_mapping {
+       void *addr;     /**< Start of the mapping, as passed to munmap(). */
+       size_t length;  /**< Length of the mapping, as passed to munmap(). */
+};

 /**
  * Memory structure includes region and mapping information.
  */
 struct virtio_memory {
-       uint64_t        base_address;   /**< Base QEMU userspace address of the memory file. */
-       uint64_t        mapped_address; /**< Mapped address of memory file base in our applications memory space. */
-       uint64_t        mapped_size;    /**< Total size of memory file. */
        uint32_t        nregions;       /**< Number of memory regions. */
-       struct virtio_memory_regions      regions[0]; /**< Memory region information. */
+       uint32_t        nmappings;      /**< Number of valid entries in mappings[]. */
+       struct virtio_memory_regions    regions[VHOST_MEMORY_MAX_NREGIONS]; /**< Memory region information. */
+       struct virtio_memory_mapping    mappings[VHOST_MEMORY_MAX_NREGIONS]; /**< Mappings to munmap() when this table is freed. */
 };

 /**
diff --git a/lib/librte_vhost/vhost_cuse/virtio-net-cdev.c b/lib/librte_vhost/vhost_cuse/virtio-net-cdev.c
index ae2c3fa..1cd0c52 100644
--- a/lib/librte_vhost/vhost_cuse/virtio-net-cdev.c
+++ b/lib/librte_vhost/vhost_cuse/virtio-net-cdev.c
@@ -278,15 +278,20 @@ cuse_set_mem_table(struct vhost_device_ctx ctx,
        if (dev == NULL)
                return -1;

-       if (dev->mem && dev->mem->mapped_address) {
-               munmap((void *)(uintptr_t)dev->mem->mapped_address,
-                       (size_t)dev->mem->mapped_size);
-               free(dev->mem);
+       if (nregions > VHOST_MEMORY_MAX_NREGIONS) {
+               RTE_LOG(ERR, VHOST_CONFIG,
+                       "(%"PRIu64") Too many memory regions (%u, max %u)\n",
+                       dev->device_fh, nregions,
+                       VHOST_MEMORY_MAX_NREGIONS);
+               return -1;
+       }
+
+       if (dev->mem) {
+               rte_vhost_free_mem(dev->mem);
                dev->mem = NULL;
        }

-       dev->mem = calloc(1, sizeof(struct virtio_memory) +
-               sizeof(struct virtio_memory_regions) * nregions);
+       dev->mem = calloc(1, sizeof(*dev->mem));
        if (dev->mem == NULL) {
                RTE_LOG(ERR, VHOST_CONFIG,
                        "(%"PRIu64") Failed to allocate memory for dev->mem\n",
@@ -325,9 +330,10 @@ cuse_set_mem_table(struct vhost_device_ctx ctx,
                                dev->mem = NULL;
                                return -1;
                        }
-                       dev->mem->mapped_address = mapped_address;
-                       dev->mem->base_address = base_address;
-                       dev->mem->mapped_size = mapped_size;
+                       /* Record the mapping so rte_vhost_free_mem() munmaps it. */
+                       rte_vhost_add_mapping(dev->mem,
+                               (void *)(uintptr_t)mapped_address,
+                               mapped_size);
                }
        }

diff --git a/lib/librte_vhost/vhost_user/virtio-net-user.c b/lib/librte_vhost/vhost_user/virtio-net-user.c
index 2934d1c..492927a 100644
--- a/lib/librte_vhost/vhost_user/virtio-net-user.c
+++ b/lib/librte_vhost/vhost_user/virtio-net-user.c
@@ -48,18 +48,6 @@
 #include "vhost-net-user.h"
 #include "vhost-net.h"

-struct orig_region_map {
-       int fd;
-       uint64_t mapped_address;
-       uint64_t mapped_size;
-       uint64_t blksz;
-};
-
-#define orig_region(ptr, nregions) \
-       ((struct orig_region_map *)RTE_PTR_ADD((ptr), \
-               sizeof(struct virtio_memory) + \
-               sizeof(struct virtio_memory_regions) * (nregions)))
-
 static uint64_t
 get_blk_size(int fd)
 {
@@ -69,34 +57,15 @@ get_blk_size(int fd)
        return (uint64_t)stat.st_blksize;
 }

-static void
-free_mem_region(struct virtio_net *dev)
-{
-       struct orig_region_map *region;
-       unsigned int idx;
-
-       if (!dev || !dev->mem)
-               return;
-
-       region = orig_region(dev->mem, dev->mem->nregions);
-       for (idx = 0; idx < dev->mem->nregions; idx++) {
-               if (region[idx].mapped_address) {
-                       munmap((void *)(uintptr_t)region[idx].mapped_address,
-                                       region[idx].mapped_size);
-                       close(region[idx].fd);
-               }
-       }
-}
-
 int
 user_set_mem_table(struct vhost_device_ctx ctx, struct VhostUserMsg *pmsg)
 {
        struct VhostUserMemory memory = pmsg->payload.memory;
        struct virtio_memory_regions *pregion;
-       uint64_t mapped_address, mapped_size;
+       void *mapped_address;
+       uint64_t mapped_size;
        struct virtio_net *dev;
        unsigned int idx = 0;
-       struct orig_region_map *pregion_orig;
        uint64_t alignment;

        /* unmap old memory regions one by one*/
@@ -104,20 +73,24 @@ user_set_mem_table(struct vhost_device_ctx ctx, struct VhostUserMsg *pmsg)
        if (dev == NULL)
                return -1;

+       if (memory.nregions > VHOST_MEMORY_MAX_NREGIONS) {
+               RTE_LOG(ERR, VHOST_CONFIG,
+                       "(%"PRIu64") Too many memory regions (%u, max %u)\n",
+                       dev->device_fh, memory.nregions,
+                       VHOST_MEMORY_MAX_NREGIONS);
+               return -1;
+       }
+
        /* Remove from the data plane. */
        if (dev->flags & VIRTIO_DEV_RUNNING)
                notify_ops->destroy_device(dev);

        if (dev->mem) {
-               free_mem_region(dev);
-               free(dev->mem);
+               rte_vhost_free_mem(dev->mem);
                dev->mem = NULL;
        }

-       dev->mem = calloc(1,
-               sizeof(struct virtio_memory) +
-               sizeof(struct virtio_memory_regions) * memory.nregions +
-               sizeof(struct orig_region_map) * memory.nregions);
+       dev->mem = calloc(1, sizeof(*dev->mem));
        if (dev->mem == NULL) {
                RTE_LOG(ERR, VHOST_CONFIG,
                        "(%"PRIu64") Failed to allocate memory for dev->mem\n",
@@ -126,7 +99,6 @@ user_set_mem_table(struct vhost_device_ctx ctx, struct VhostUserMsg *pmsg)
        }
        dev->mem->nregions = memory.nregions;

-       pregion_orig = orig_region(dev->mem, memory.nregions);
        for (idx = 0; idx < memory.nregions; idx++) {
                pregion = &dev->mem->regions[idx];
                pregion->guest_phys_address =
@@ -154,7 +126,7 @@ user_set_mem_table(struct vhost_device_ctx ctx, struct VhostUserMsg *pmsg)
                alignment = get_blk_size(pmsg->fds[idx]);
                mapped_size = RTE_ALIGN_CEIL(mapped_size, alignment);

-               mapped_address = (uint64_t)(uintptr_t)mmap(NULL,
+               mapped_address = mmap(NULL,
                        mapped_size,
                        PROT_READ | PROT_WRITE, MAP_SHARED,
                        pmsg->fds[idx],
@@ -163,33 +135,23 @@ user_set_mem_table(struct vhost_device_ctx ctx, struct VhostUserMsg *pmsg)
                RTE_LOG(INFO, VHOST_CONFIG,
                        "mapped region %d fd:%d to:%p sz:0x%"PRIx64" "
                        "off:0x%"PRIx64" align:0x%"PRIx64"\n",
-                       idx, pmsg->fds[idx], (void *)(uintptr_t)mapped_address,
+                       idx, pmsg->fds[idx], mapped_address,
                        mapped_size, memory.regions[idx].mmap_offset,
                        alignment);

-               if (mapped_address == (uint64_t)(uintptr_t)MAP_FAILED) {
+               if (mapped_address == MAP_FAILED) {
                        RTE_LOG(ERR, VHOST_CONFIG,
                                "mmap qemu guest failed.\n");
                        goto err_mmap;
                }

-               pregion_orig[idx].mapped_address = mapped_address;
-               pregion_orig[idx].mapped_size = mapped_size;
-               pregion_orig[idx].blksz = alignment;
-               pregion_orig[idx].fd = pmsg->fds[idx];
+               rte_vhost_add_mapping(dev->mem, mapped_address, mapped_size);

-               mapped_address +=  memory.regions[idx].mmap_offset;
-
-               pregion->address_offset = mapped_address -
+               pregion->address_offset =
+                       (uint64_t)(uintptr_t)mapped_address +
+                       memory.regions[idx].mmap_offset -
                        pregion->guest_phys_address;

-               if (memory.regions[idx].guest_phys_addr == 0) {
-                       dev->mem->base_address =
-                               memory.regions[idx].userspace_addr;
-                       dev->mem->mapped_address =
-                               pregion->address_offset;
-               }
-
                LOG_DEBUG(VHOST_CONFIG,
                        "REGION: %u GPA: %p QEMU VA: %p SIZE (%"PRIu64")\n",
                        idx,
@@ -198,15 +160,15 @@ user_set_mem_table(struct vhost_device_ctx ctx, struct VhostUserMsg *pmsg)
                         pregion->memory_size);
        }

+       for (idx = 0; idx < memory.nregions; idx++)
+               close(pmsg->fds[idx]);
+
        return 0;

 err_mmap:
-       while (idx--) {
-               munmap((void *)(uintptr_t)pregion_orig[idx].mapped_address,
-                               pregion_orig[idx].mapped_size);
-               close(pregion_orig[idx].fd);
-       }
-       free(dev->mem);
+       for (idx = 0; idx < memory.nregions; idx++)
+               close(pmsg->fds[idx]);
+       rte_vhost_free_mem(dev->mem);
        dev->mem = NULL;
        return -1;
 }
@@ -347,7 +309,6 @@ user_destroy_device(struct vhost_device_ctx ctx)
                notify_ops->destroy_device(dev);

        if (dev && dev->mem) {
-               free_mem_region(dev);
-               free(dev->mem);
+               rte_vhost_free_mem(dev->mem);
                dev->mem = NULL;
        }
diff --git a/lib/librte_vhost/virtio-net.c b/lib/librte_vhost/virtio-net.c
index de78a0f..cd0e09e 100644
--- a/lib/librte_vhost/virtio-net.c
+++ b/lib/librte_vhost/virtio-net.c
@@ -201,9 +201,7 @@ cleanup_device(struct virtio_net *dev, int destroy)

        /* Unmap QEMU memory file if mapped. */
        if (dev->mem) {
-               munmap((void *)(uintptr_t)dev->mem->mapped_address,
-                       (size_t)dev->mem->mapped_size);
-               free(dev->mem);
+               rte_vhost_free_mem(dev->mem);
                dev->mem = NULL;
        }

@@ -897,3 +895,23 @@ rte_vhost_driver_callback_register(struct virtio_net_device_ops const * const op

        return 0;
 }
+
+void
+rte_vhost_free_mem(struct virtio_memory *mem)
+{
+       unsigned i;
+
+       for (i = 0; i < mem->nmappings; i++)
+               munmap(mem->mappings[i].addr, mem->mappings[i].length);
+
+       free(mem);
+}
+
+void
+rte_vhost_add_mapping(struct virtio_memory *mem, void *addr, size_t length)
+{
+       struct virtio_memory_mapping *m = &mem->mappings[mem->nmappings++];
+
+       m->addr = addr;
+       m->length = length;
+}
diff --git a/lib/librte_vhost/virtio-net.h b/lib/librte_vhost/virtio-net.h
index 75fb57e..a2135b0 100644
--- a/lib/librte_vhost/virtio-net.h
+++ b/lib/librte_vhost/virtio-net.h
@@ -40,4 +40,7 @@
 struct virtio_net_device_ops const *notify_ops;
 struct virtio_net *get_device(struct vhost_device_ctx ctx);

+void rte_vhost_free_mem(struct virtio_memory *mem);
+void rte_vhost_add_mapping(struct virtio_memory *mem, void *addr, size_t length);
+
 #endif
-- 
1.9.1

Reply via email to