When the open_device() op is called, container_users is incremented and
held incremented until close_device(). Thus, as long as drivers call
functions within their open_device()/close_device() region, they do not
need to worry about container_users.

These functions can all only be called between open_device() and
close_device():

  vfio_pin_pages()
  vfio_unpin_pages()
  vfio_dma_rw()
  vfio_register_notifier()
  vfio_unregister_notifier()

Eliminate the calls to vfio_group_add_container_user() and add
vfio_assert_device_open() to detect driver misuse. Since close_device()
implementations may call these functions, and they now check
device->open_count, always leave open_count elevated while calling the
close_device() op and only decrement it afterwards.

Reviewed-by: Christoph Hellwig <h...@lst.de>
Reviewed-by: Kevin Tian <kevin.t...@intel.com>
Signed-off-by: Jason Gunthorpe <j...@nvidia.com>
---
 drivers/vfio/vfio.c | 80 ++++++++++-----------------------------------
 1 file changed, 17 insertions(+), 63 deletions(-)

diff --git a/drivers/vfio/vfio.c b/drivers/vfio/vfio.c
index c651c4805acd59..8bb38941c1dfd8 100644
--- a/drivers/vfio/vfio.c
+++ b/drivers/vfio/vfio.c
@@ -1115,6 +1115,12 @@ static int vfio_group_add_container_user(struct vfio_group *group)
 
 static const struct file_operations vfio_device_fops;
 
+/* true if the vfio_device has open_device() called but not close_device() */
+static bool vfio_assert_device_open(struct vfio_device *device)
+{
+       return !WARN_ON_ONCE(!READ_ONCE(device->open_count));
+}
+
 static int vfio_group_get_device_fd(struct vfio_group *group, char *buf)
 {
        struct vfio_device *device;
@@ -1329,8 +1335,10 @@ static int vfio_device_fops_release(struct inode *inode, struct file *filep)
        struct vfio_device *device = filep->private_data;
 
        mutex_lock(&device->dev_set->lock);
-       if (!--device->open_count && device->ops->close_device)
+       vfio_assert_device_open(device);
+       if (device->open_count == 1 && device->ops->close_device)
                device->ops->close_device(device);
+       device->open_count--;
        mutex_unlock(&device->dev_set->lock);
 
        module_put(device->dev->driver->owner);
@@ -1897,7 +1905,8 @@ int vfio_pin_pages(struct vfio_device *device, unsigned long *user_pfn,
        struct vfio_iommu_driver *driver;
        int ret;
 
-       if (!user_pfn || !phys_pfn || !npage)
+       if (!user_pfn || !phys_pfn || !npage ||
+           !vfio_assert_device_open(device))
                return -EINVAL;
 
        if (npage > VFIO_PIN_PAGES_MAX_ENTRIES)
@@ -1906,10 +1915,6 @@ int vfio_pin_pages(struct vfio_device *device, unsigned long *user_pfn,
        if (group->dev_counter > 1)
                return -EINVAL;
 
-       ret = vfio_group_add_container_user(group);
-       if (ret)
-               return ret;
-
        container = group->container;
        driver = container->iommu_driver;
        if (likely(driver && driver->ops->pin_pages))
@@ -1919,8 +1924,6 @@ int vfio_pin_pages(struct vfio_device *device, unsigned long *user_pfn,
        else
                ret = -ENOTTY;
 
-       vfio_group_try_dissolve_container(group);
-
        return ret;
 }
 EXPORT_SYMBOL(vfio_pin_pages);
@@ -1941,16 +1944,12 @@ int vfio_unpin_pages(struct vfio_device *device, unsigned long *user_pfn,
        struct vfio_iommu_driver *driver;
        int ret;
 
-       if (!user_pfn || !npage)
+       if (!user_pfn || !npage || !vfio_assert_device_open(device))
                return -EINVAL;
 
        if (npage > VFIO_PIN_PAGES_MAX_ENTRIES)
                return -E2BIG;
 
-       ret = vfio_group_add_container_user(device->group);
-       if (ret)
-               return ret;
-
        container = device->group->container;
        driver = container->iommu_driver;
        if (likely(driver && driver->ops->unpin_pages))
@@ -1959,8 +1958,6 @@ int vfio_unpin_pages(struct vfio_device *device, unsigned long *user_pfn,
        else
                ret = -ENOTTY;
 
-       vfio_group_try_dissolve_container(device->group);
-
        return ret;
 }
 EXPORT_SYMBOL(vfio_unpin_pages);
@@ -1989,13 +1986,9 @@ int vfio_dma_rw(struct vfio_device *device, dma_addr_t user_iova, void *data,
        struct vfio_iommu_driver *driver;
        int ret = 0;
 
-       if (!data || len <= 0)
+       if (!data || len <= 0 || !vfio_assert_device_open(device))
                return -EINVAL;
 
-       ret = vfio_group_add_container_user(device->group);
-       if (ret)
-               return ret;
-
        container = device->group->container;
        driver = container->iommu_driver;
 
@@ -2004,9 +1997,6 @@ int vfio_dma_rw(struct vfio_device *device, dma_addr_t user_iova, void *data,
                                          user_iova, data, len, write);
        else
                ret = -ENOTTY;
-
-       vfio_group_try_dissolve_container(device->group);
-
        return ret;
 }
 EXPORT_SYMBOL(vfio_dma_rw);
@@ -2019,10 +2009,6 @@ static int vfio_register_iommu_notifier(struct vfio_group *group,
        struct vfio_iommu_driver *driver;
        int ret;
 
-       ret = vfio_group_add_container_user(group);
-       if (ret)
-               return -EINVAL;
-
        container = group->container;
        driver = container->iommu_driver;
        if (likely(driver && driver->ops->register_notifier))
@@ -2030,9 +2016,6 @@ static int vfio_register_iommu_notifier(struct vfio_group *group,
                                                     events, nb);
        else
                ret = -ENOTTY;
-
-       vfio_group_try_dissolve_container(group);
-
        return ret;
 }
 
@@ -2043,10 +2026,6 @@ static int vfio_unregister_iommu_notifier(struct vfio_group *group,
        struct vfio_iommu_driver *driver;
        int ret;
 
-       ret = vfio_group_add_container_user(group);
-       if (ret)
-               return -EINVAL;
-
        container = group->container;
        driver = container->iommu_driver;
        if (likely(driver && driver->ops->unregister_notifier))
@@ -2054,9 +2033,6 @@ static int vfio_unregister_iommu_notifier(struct vfio_group *group,
                                                       nb);
        else
                ret = -ENOTTY;
-
-       vfio_group_try_dissolve_container(group);
-
        return ret;
 }
 
@@ -2085,10 +2061,6 @@ static int vfio_register_group_notifier(struct vfio_group *group,
        if (*events)
                return -EINVAL;
 
-       ret = vfio_group_add_container_user(group);
-       if (ret)
-               return -EINVAL;
-
        ret = blocking_notifier_chain_register(&group->notifier, nb);
 
        /*
@@ -2098,25 +2070,6 @@ static int vfio_register_group_notifier(struct vfio_group *group,
        if (!ret && set_kvm && group->kvm)
                blocking_notifier_call_chain(&group->notifier,
                                        VFIO_GROUP_NOTIFY_SET_KVM, group->kvm);
-
-       vfio_group_try_dissolve_container(group);
-
-       return ret;
-}
-
-static int vfio_unregister_group_notifier(struct vfio_group *group,
-                                        struct notifier_block *nb)
-{
-       int ret;
-
-       ret = vfio_group_add_container_user(group);
-       if (ret)
-               return -EINVAL;
-
-       ret = blocking_notifier_chain_unregister(&group->notifier, nb);
-
-       vfio_group_try_dissolve_container(group);
-
        return ret;
 }
 
@@ -2127,7 +2080,8 @@ int vfio_register_notifier(struct vfio_device *device,
        struct vfio_group *group = device->group;
        int ret;
 
-       if (!nb || !events || (*events == 0))
+       if (!nb || !events || (*events == 0) ||
+           !vfio_assert_device_open(device))
                return -EINVAL;
 
        switch (type) {
@@ -2151,7 +2105,7 @@ int vfio_unregister_notifier(struct vfio_device *device,
        struct vfio_group *group = device->group;
        int ret;
 
-       if (!nb)
+       if (!nb || !vfio_assert_device_open(device))
                return -EINVAL;
 
        switch (type) {
@@ -2159,7 +2113,7 @@ int vfio_unregister_notifier(struct vfio_device *device,
                ret = vfio_unregister_iommu_notifier(group, nb);
                break;
        case VFIO_GROUP_NOTIFY:
-               ret = vfio_unregister_group_notifier(group, nb);
+               ret = blocking_notifier_chain_unregister(&group->notifier, nb);
                break;
        default:
                ret = -EINVAL;
-- 
2.36.0

Reply via email to