virtio has the complication that it sometimes wants to return a paging
domain for IDENTITY, which makes this conversion a little different from
other drivers.

Add a viommu_domain_alloc_paging() that combines viommu_domain_alloc() and
viommu_domain_finalise() to always return a fully initialized and
finalized paging domain.
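
As a minimal illustration of the new contract (not part of the diff
itself): viommu_domain_alloc_paging() reports failure with ERR_PTR
codes rather than NULL, so a caller looks roughly like:

	struct iommu_domain *domain;

	domain = viommu_domain_alloc_paging(dev);
	if (IS_ERR(domain))
		return PTR_ERR(domain);	/* e.g. -ENODEV or -ENOMEM */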

Add viommu_domain_alloc_identity() to implement the special non-bypass
IDENTITY flow by calling viommu_domain_alloc_paging() and then
viommu_domain_map_identity().
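
In shorthand (lifted from the diff below; the trailing comment is
editorial), the identity op becomes:

	if (virtio_has_feature(vdev->viommu->vdev,
			       VIRTIO_IOMMU_F_BYPASS_CONFIG))
		return &viommu_identity_domain.domain;	/* real bypass */
	/* otherwise: emulate IDENTITY with a 1:1 mapped paging domain */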

Remove support for deferred finalize and the vdomain->mutex.

Remove core support for domain_alloc() IDENTITY as virtio was the last
driver using it.
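
After this patch the core's identity allocation falls back roughly as
follows (a simplified sketch of __iommu_alloc_identity_domain(),
assuming the usual ops->identity_domain fast path ahead of the branch
shown in the diff):

	if (ops->identity_domain)		/* static identity domain */
		return ops->identity_domain;
	if (ops->domain_alloc_identity)		/* per-device allocation */
		return ops->domain_alloc_identity(dev);
	return ERR_PTR(-EOPNOTSUPP);		/* domain_alloc() path is gone */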

Reviewed-by: Jean-Philippe Brucker <jean-phili...@linaro.org>
Signed-off-by: Jason Gunthorpe <j...@nvidia.com>
---
 drivers/iommu/iommu.c        |   6 --
 drivers/iommu/virtio-iommu.c | 121 +++++++++++++++--------------------
 2 files changed, 53 insertions(+), 74 deletions(-)

diff --git a/drivers/iommu/iommu.c b/drivers/iommu/iommu.c
index ee33d26dfcd40d..73a05b34de4768 100644
--- a/drivers/iommu/iommu.c
+++ b/drivers/iommu/iommu.c
@@ -1599,12 +1599,6 @@ static struct iommu_domain *__iommu_alloc_identity_domain(struct device *dev)
                domain = ops->domain_alloc_identity(dev);
                if (IS_ERR(domain))
                        return domain;
-       } else if (ops->domain_alloc) {
-               domain = ops->domain_alloc(IOMMU_DOMAIN_IDENTITY);
-               if (!domain)
-                       return ERR_PTR(-ENOMEM);
-               if (IS_ERR(domain))
-                       return domain;
        } else {
                return ERR_PTR(-EOPNOTSUPP);
        }
diff --git a/drivers/iommu/virtio-iommu.c b/drivers/iommu/virtio-iommu.c
index 55a2188197c621..ecd41fb03e5a51 100644
--- a/drivers/iommu/virtio-iommu.c
+++ b/drivers/iommu/virtio-iommu.c
@@ -63,7 +63,6 @@ struct viommu_mapping {
 struct viommu_domain {
        struct iommu_domain             domain;
        struct viommu_dev               *viommu;
-       struct mutex                    mutex; /* protects viommu pointer */
        unsigned int                    id;
        u32                             map_flags;
 
@@ -97,6 +96,8 @@ struct viommu_event {
        };
 };
 
+static struct viommu_domain viommu_identity_domain;
+
 #define to_viommu_domain(domain)       \
        container_of(domain, struct viommu_domain, domain)
 
@@ -653,65 +654,45 @@ static void viommu_event_handler(struct virtqueue *vq)
 
 /* IOMMU API */
 
-static struct iommu_domain *viommu_domain_alloc(unsigned type)
+static struct iommu_domain *viommu_domain_alloc_paging(struct device *dev)
 {
-       struct viommu_domain *vdomain;
-
-       if (type != IOMMU_DOMAIN_UNMANAGED &&
-           type != IOMMU_DOMAIN_DMA &&
-           type != IOMMU_DOMAIN_IDENTITY)
-               return NULL;
-
-       vdomain = kzalloc(sizeof(*vdomain), GFP_KERNEL);
-       if (!vdomain)
-               return NULL;
-
-       mutex_init(&vdomain->mutex);
-       spin_lock_init(&vdomain->mappings_lock);
-       vdomain->mappings = RB_ROOT_CACHED;
-
-       return &vdomain->domain;
-}
-
-static int viommu_domain_finalise(struct viommu_endpoint *vdev,
-                                 struct iommu_domain *domain)
-{
-       int ret;
-       unsigned long viommu_page_size;
+       struct viommu_endpoint *vdev = dev_iommu_priv_get(dev);
        struct viommu_dev *viommu = vdev->viommu;
-       struct viommu_domain *vdomain = to_viommu_domain(domain);
+       unsigned long viommu_page_size;
+       struct viommu_domain *vdomain;
+       int ret;
 
        viommu_page_size = 1UL << __ffs(viommu->pgsize_bitmap);
        if (viommu_page_size > PAGE_SIZE) {
                dev_err(vdev->dev,
                        "granule 0x%lx larger than system page size 0x%lx\n",
                        viommu_page_size, PAGE_SIZE);
-               return -ENODEV;
+               return ERR_PTR(-ENODEV);
        }
 
+       vdomain = kzalloc(sizeof(*vdomain), GFP_KERNEL);
+       if (!vdomain)
+               return ERR_PTR(-ENOMEM);
+
+       spin_lock_init(&vdomain->mappings_lock);
+       vdomain->mappings = RB_ROOT_CACHED;
+
        ret = ida_alloc_range(&viommu->domain_ids, viommu->first_domain,
                              viommu->last_domain, GFP_KERNEL);
-       if (ret < 0)
-               return ret;
-
-       vdomain->id             = (unsigned int)ret;
-
-       domain->pgsize_bitmap   = viommu->pgsize_bitmap;
-       domain->geometry        = viommu->geometry;
-
-       vdomain->map_flags      = viommu->map_flags;
-       vdomain->viommu         = viommu;
-
-       if (domain->type == IOMMU_DOMAIN_IDENTITY) {
-               ret = viommu_domain_map_identity(vdev, vdomain);
-               if (ret) {
-                       ida_free(&viommu->domain_ids, vdomain->id);
-                       vdomain->viommu = NULL;
-                       return ret;
-               }
+       if (ret < 0) {
+               kfree(vdomain);
+               return ERR_PTR(ret);
        }
 
-       return 0;
+       vdomain->id = (unsigned int)ret;
+
+       vdomain->domain.pgsize_bitmap = viommu->pgsize_bitmap;
+       vdomain->domain.geometry = viommu->geometry;
+
+       vdomain->map_flags = viommu->map_flags;
+       vdomain->viommu = viommu;
+
+       return &vdomain->domain;
 }
 
 static void viommu_domain_free(struct iommu_domain *domain)
@@ -727,6 +708,28 @@ static void viommu_domain_free(struct iommu_domain *domain)
        kfree(vdomain);
 }
 
+static struct iommu_domain *viommu_domain_alloc_identity(struct device *dev)
+{
+       struct viommu_endpoint *vdev = dev_iommu_priv_get(dev);
+       struct iommu_domain *domain;
+       int ret;
+
+       if (virtio_has_feature(vdev->viommu->vdev,
+                              VIRTIO_IOMMU_F_BYPASS_CONFIG))
+               return &viommu_identity_domain.domain;
+
+       domain = viommu_domain_alloc_paging(dev);
+       if (IS_ERR(domain))
+               return domain;
+
+       ret = viommu_domain_map_identity(vdev, to_viommu_domain(domain));
+       if (ret) {
+               viommu_domain_free(domain);
+               return ERR_PTR(ret);
+       }
+       return domain;
+}
+
 static int viommu_attach_dev(struct iommu_domain *domain, struct device *dev)
 {
        int ret = 0;
@@ -734,20 +737,8 @@ static int viommu_attach_dev(struct iommu_domain *domain, struct device *dev)
        struct viommu_endpoint *vdev = dev_iommu_priv_get(dev);
        struct viommu_domain *vdomain = to_viommu_domain(domain);
 
-       mutex_lock(&vdomain->mutex);
-       if (!vdomain->viommu) {
-               /*
-                * Properly initialize the domain now that we know which viommu
-                * owns it.
-                */
-               ret = viommu_domain_finalise(vdev, domain);
-       } else if (vdomain->viommu != vdev->viommu) {
-               ret = -EINVAL;
-       }
-       mutex_unlock(&vdomain->mutex);
-
-       if (ret)
-               return ret;
+       if (vdomain->viommu != vdev->viommu)
+               return -EINVAL;
 
        /*
         * In the virtio-iommu device, when attaching the endpoint to a new
@@ -1096,9 +1087,9 @@ static bool viommu_capable(struct device *dev, enum iommu_cap cap)
 }
 
 static struct iommu_ops viommu_ops = {
-       .identity_domain        = &viommu_identity_domain.domain,
        .capable                = viommu_capable,
-       .domain_alloc           = viommu_domain_alloc,
+       .domain_alloc_identity  = viommu_domain_alloc_identity,
+       .domain_alloc_paging    = viommu_domain_alloc_paging,
        .probe_device           = viommu_probe_device,
        .release_device         = viommu_release_device,
        .device_group           = viommu_device_group,
@@ -1224,12 +1215,6 @@ static int viommu_probe(struct virtio_device *vdev)
        if (virtio_has_feature(viommu->vdev, VIRTIO_IOMMU_F_BYPASS_CONFIG)) {
                viommu->identity_domain_id = viommu->first_domain;
                viommu->first_domain++;
-       } else {
-               /*
-                * Assume the VMM is sensible and it either supports bypass on
-                * all instances or no instances.
-                */
-               viommu_ops.identity_domain = NULL;
        }
 
        viommu_ops.pgsize_bitmap = viommu->pgsize_bitmap;
-- 
2.43.0

