On Thu, Feb 27, 2025 at 08:20:01PM -0400, Jason Gunthorpe wrote:
> virtio has the complication that it sometimes wants to return a paging
> domain for IDENTITY, which makes this conversion a little different from
> other drivers.
> 
> Add a viommu_domain_alloc_paging() that combines viommu_domain_alloc() and
> viommu_domain_finalise() to always return a fully initialized and
> finalized paging domain.
> 
> Use viommu_domain_alloc_identity() to implement the special non-bypass
> IDENTITY flow by calling viommu_domain_alloc_paging() then
> viommu_domain_map_identity().
> 
> Remove support for deferred finalize and the vdomain->mutex.
> 
> Remove core support for allocating IDENTITY domains through domain_alloc(),
> as virtio was the last driver using it.
> 
> Signed-off-by: Jason Gunthorpe <j...@nvidia.com>

Reviewed-by: Jean-Philippe Brucker <jean-phili...@linaro.org>

And my tests still pass (after fixing the build issue in patch 2).

> ---
>  drivers/iommu/iommu.c        |   6 --
>  drivers/iommu/virtio-iommu.c | 114 ++++++++++++++++-------------------
>  2 files changed, 53 insertions(+), 67 deletions(-)
> 
> diff --git a/drivers/iommu/iommu.c b/drivers/iommu/iommu.c
> index ee33d26dfcd40d..73a05b34de4768 100644
> --- a/drivers/iommu/iommu.c
> +++ b/drivers/iommu/iommu.c
> @@ -1599,12 +1599,6 @@ static struct iommu_domain *__iommu_alloc_identity_domain(struct device *dev)
>               domain = ops->domain_alloc_identity(dev);
>               if (IS_ERR(domain))
>                       return domain;
> -     } else if (ops->domain_alloc) {
> -             domain = ops->domain_alloc(IOMMU_DOMAIN_IDENTITY);
> -             if (!domain)
> -                     return ERR_PTR(-ENOMEM);
> -             if (IS_ERR(domain))
> -                     return domain;
>       } else {
>               return ERR_PTR(-EOPNOTSUPP);
>       }
> diff --git a/drivers/iommu/virtio-iommu.c b/drivers/iommu/virtio-iommu.c
> index 2287b89967067d..cdb8034ece9bde 100644
> --- a/drivers/iommu/virtio-iommu.c
> +++ b/drivers/iommu/virtio-iommu.c
> @@ -63,7 +63,6 @@ struct viommu_mapping {
>  struct viommu_domain {
>       struct iommu_domain             domain;
>       struct viommu_dev               *viommu;
> -     struct mutex                    mutex; /* protects viommu pointer */
>       unsigned int                    id;
>       u32                             map_flags;
>  
> @@ -97,6 +96,8 @@ struct viommu_event {
>       };
>  };
>  
> +static struct viommu_domain viommu_identity_domain;
> +
>  #define to_viommu_domain(domain)     \
>       container_of(domain, struct viommu_domain, domain)
>  
> @@ -653,65 +654,45 @@ static void viommu_event_handler(struct virtqueue *vq)
>  
>  /* IOMMU API */
>  
> -static struct iommu_domain *viommu_domain_alloc(unsigned type)
> +static struct iommu_domain *viommu_domain_alloc_paging(struct device *dev)
>  {
> -     struct viommu_domain *vdomain;
> -
> -     if (type != IOMMU_DOMAIN_UNMANAGED &&
> -         type != IOMMU_DOMAIN_DMA &&
> -         type != IOMMU_DOMAIN_IDENTITY)
> -             return NULL;
> -
> -     vdomain = kzalloc(sizeof(*vdomain), GFP_KERNEL);
> -     if (!vdomain)
> -             return NULL;
> -
> -     mutex_init(&vdomain->mutex);
> -     spin_lock_init(&vdomain->mappings_lock);
> -     vdomain->mappings = RB_ROOT_CACHED;
> -
> -     return &vdomain->domain;
> -}
> -
> -static int viommu_domain_finalise(struct viommu_endpoint *vdev,
> -                               struct iommu_domain *domain)
> -{
> -     int ret;
> -     unsigned long viommu_page_size;
> +     struct viommu_endpoint *vdev = dev_iommu_priv_get(dev);
>       struct viommu_dev *viommu = vdev->viommu;
> -     struct viommu_domain *vdomain = to_viommu_domain(domain);
> +     unsigned long viommu_page_size;
> +     struct viommu_domain *vdomain;
> +     int ret;
>  
>       viommu_page_size = 1UL << __ffs(viommu->pgsize_bitmap);
>       if (viommu_page_size > PAGE_SIZE) {
>               dev_err(vdev->dev,
>                       "granule 0x%lx larger than system page size 0x%lx\n",
>                       viommu_page_size, PAGE_SIZE);
> -             return -ENODEV;
> +             return ERR_PTR(-ENODEV);
>       }
>  
> +     vdomain = kzalloc(sizeof(*vdomain), GFP_KERNEL);
> +     if (!vdomain)
> +             return ERR_PTR(-ENOMEM);
> +
> +     spin_lock_init(&vdomain->mappings_lock);
> +     vdomain->mappings = RB_ROOT_CACHED;
> +
>       ret = ida_alloc_range(&viommu->domain_ids, viommu->first_domain,
>                             viommu->last_domain, GFP_KERNEL);
> -     if (ret < 0)
> -             return ret;
> -
> -     vdomain->id             = (unsigned int)ret;
> -
> -     domain->pgsize_bitmap   = viommu->pgsize_bitmap;
> -     domain->geometry        = viommu->geometry;
> -
> -     vdomain->map_flags      = viommu->map_flags;
> -     vdomain->viommu         = viommu;
> -
> -     if (domain->type == IOMMU_DOMAIN_IDENTITY) {
> -             ret = viommu_domain_map_identity(vdev, vdomain);
> -             if (ret) {
> -                     ida_free(&viommu->domain_ids, vdomain->id);
> -                     vdomain->viommu = NULL;
> -                     return ret;
> -             }
> +     if (ret < 0) {
> +             kfree(vdomain);
> +             return ERR_PTR(ret);
>       }
>  
> -     return 0;
> +     vdomain->id = (unsigned int)ret;
> +
> +     vdomain->domain.pgsize_bitmap = viommu->pgsize_bitmap;
> +     vdomain->domain.geometry = viommu->geometry;
> +
> +     vdomain->map_flags = viommu->map_flags;
> +     vdomain->viommu = viommu;
> +
> +     return &vdomain->domain;
>  }
>  
>  static void viommu_domain_free(struct iommu_domain *domain)
> @@ -727,6 +708,28 @@ static void viommu_domain_free(struct iommu_domain *domain)
>       kfree(vdomain);
>  }
>  
> +static struct iommu_domain *viommu_domain_alloc_identity(struct device *dev)
> +{
> +     struct viommu_endpoint *vdev = dev_iommu_priv_get(dev);
> +     struct iommu_domain *domain;
> +     int ret;
> +
> +     if (virtio_has_feature(vdev->viommu->vdev,
> +                            VIRTIO_IOMMU_F_BYPASS_CONFIG))
> +             return &viommu_identity_domain.domain;
> +
> +     domain = viommu_domain_alloc_paging(dev);
> +     if (IS_ERR(domain))
> +             return domain;
> +
> +     ret = viommu_domain_map_identity(vdev, to_viommu_domain(domain));
> +     if (ret) {
> +             viommu_domain_free(domain);
> +             return ERR_PTR(ret);
> +     }
> +     return domain;
> +}
> +
>  static int viommu_attach_dev(struct iommu_domain *domain, struct device *dev)
>  {
>       int ret = 0;
> @@ -734,20 +737,8 @@ static int viommu_attach_dev(struct iommu_domain *domain, struct device *dev)
>       struct viommu_endpoint *vdev = dev_iommu_priv_get(dev);
>       struct viommu_domain *vdomain = to_viommu_domain(domain);
>  
> -     mutex_lock(&vdomain->mutex);
> -     if (!vdomain->viommu) {
> -             /*
> -              * Properly initialize the domain now that we know which viommu
> -              * owns it.
> -              */
> -             ret = viommu_domain_finalise(vdev, domain);
> -     } else if (vdomain->viommu != vdev->viommu) {
> -             ret = -EINVAL;
> -     }
> -     mutex_unlock(&vdomain->mutex);
> -
> -     if (ret)
> -             return ret;
> +     if (vdomain->viommu != vdev->viommu)
> +             return -EINVAL;
>  
>       /*
>        * In the virtio-iommu device, when attaching the endpoint to a new
> @@ -1097,7 +1088,8 @@ static bool viommu_capable(struct device *dev, enum iommu_cap cap)
>  
>  static const struct iommu_ops viommu_ops = {
>       .capable                = viommu_capable,
> -     .domain_alloc           = viommu_domain_alloc,
> +     .domain_alloc_identity  = viommu_domain_alloc_identity,
> +     .domain_alloc_paging    = viommu_domain_alloc_paging,
>       .probe_device           = viommu_probe_device,
>       .release_device         = viommu_release_device,
>       .device_group           = viommu_device_group,
> -- 
> 2.43.0
> 
> 

Reply via email to