With a vIOMMU object, user space can flush any IOMMU-related cache that can
be directed via a vIOMMU object. It is similar to the IOMMU_HWPT_INVALIDATE
uAPI, but can cover a wider range than IOTLB, e.g. device/descriptor cache.

Allow hwpt_id of the iommu_hwpt_invalidate structure to carry a viommu_id,
and reuse the IOMMU_HWPT_INVALIDATE uAPI for vIOMMU invalidations. Drivers
can define different structures for vIOMMU invalidations vs. HWPT ones.

Since both the HWPT-based and vIOMMU-based invalidation pathways check their
own cache invalidation op, drop the cache_invalidate_user check from the
WARN_ON_ONCE in iommufd_hwpt_nested_alloc().

Update the uAPI, kdoc, and selftest case accordingly.

Reviewed-by: Jason Gunthorpe <j...@nvidia.com>
Reviewed-by: Kevin Tian <kevin.t...@intel.com>
Signed-off-by: Nicolin Chen <nicol...@nvidia.com>
---
 include/uapi/linux/iommufd.h            |  9 ++++--
 drivers/iommu/iommufd/hw_pagetable.c    | 40 +++++++++++++++++++------
 tools/testing/selftests/iommu/iommufd.c |  4 +--
 3 files changed, 39 insertions(+), 14 deletions(-)

diff --git a/include/uapi/linux/iommufd.h b/include/uapi/linux/iommufd.h
index 9b5236004b8e..badb41c5bfa4 100644
--- a/include/uapi/linux/iommufd.h
+++ b/include/uapi/linux/iommufd.h
@@ -700,7 +700,7 @@ struct iommu_hwpt_vtd_s1_invalidate {
 /**
  * struct iommu_hwpt_invalidate - ioctl(IOMMU_HWPT_INVALIDATE)
  * @size: sizeof(struct iommu_hwpt_invalidate)
- * @hwpt_id: ID of a nested HWPT for cache invalidation
+ * @hwpt_id: ID of a nested HWPT or a vIOMMU, for cache invalidation
  * @data_uptr: User pointer to an array of driver-specific cache invalidation
  *             data.
  * @data_type: One of enum iommu_hwpt_invalidate_data_type, defining the data
@@ -711,8 +711,11 @@ struct iommu_hwpt_vtd_s1_invalidate {
  *             Output the number of requests successfully handled by kernel.
  * @__reserved: Must be 0.
  *
- * Invalidate the iommu cache for user-managed page table. Modifications on a
- * user-managed page table should be followed by this operation to sync cache.
+ * Invalidate iommu cache for user-managed page table or vIOMMU. Modifications
+ * on a user-managed page table should be followed by this operation, if a HWPT
+ * is passed in via @hwpt_id. Other caches, such as device cache or descriptor
+ * cache can be flushed if a vIOMMU is passed in via the @hwpt_id field.
+ *
  * Each ioctl can support one or more cache invalidation requests in the array
  * that has a total size of @entry_len * @entry_num.
  *
diff --git a/drivers/iommu/iommufd/hw_pagetable.c b/drivers/iommu/iommufd/hw_pagetable.c
index 982bf4a35a2b..702057655a81 100644
--- a/drivers/iommu/iommufd/hw_pagetable.c
+++ b/drivers/iommu/iommufd/hw_pagetable.c
@@ -251,8 +251,7 @@ iommufd_hwpt_nested_alloc(struct iommufd_ctx *ictx,
        }
        hwpt->domain->owner = ops;
 
-       if (WARN_ON_ONCE(hwpt->domain->type != IOMMU_DOMAIN_NESTED ||
-                        !hwpt->domain->ops->cache_invalidate_user)) {
+       if (WARN_ON_ONCE(hwpt->domain->type != IOMMU_DOMAIN_NESTED)) {
                rc = -EINVAL;
                goto out_abort;
        }
@@ -483,7 +482,7 @@ int iommufd_hwpt_invalidate(struct iommufd_ucmd *ucmd)
                .entry_len = cmd->entry_len,
                .entry_num = cmd->entry_num,
        };
-       struct iommufd_hw_pagetable *hwpt;
+       struct iommufd_object *pt_obj;
        u32 done_num = 0;
        int rc;
 
@@ -497,17 +496,40 @@ int iommufd_hwpt_invalidate(struct iommufd_ucmd *ucmd)
                goto out;
        }
 
-       hwpt = iommufd_get_hwpt_nested(ucmd, cmd->hwpt_id);
-       if (IS_ERR(hwpt)) {
-               rc = PTR_ERR(hwpt);
+       pt_obj = iommufd_get_object(ucmd->ictx, cmd->hwpt_id, IOMMUFD_OBJ_ANY);
+       if (IS_ERR(pt_obj)) {
+               rc = PTR_ERR(pt_obj);
                goto out;
        }
+       if (pt_obj->type == IOMMUFD_OBJ_HWPT_NESTED) {
+               struct iommufd_hw_pagetable *hwpt =
+                       container_of(pt_obj, struct iommufd_hw_pagetable, obj);
+
+               if (!hwpt->domain->ops ||
+                   !hwpt->domain->ops->cache_invalidate_user) {
+                       rc = -EOPNOTSUPP;
+                       goto out_put_pt;
+               }
+               rc = hwpt->domain->ops->cache_invalidate_user(hwpt->domain,
+                                                             &data_array);
+       } else if (pt_obj->type == IOMMUFD_OBJ_VIOMMU) {
+               struct iommufd_viommu *viommu =
+                       container_of(pt_obj, struct iommufd_viommu, obj);
+
+               if (!viommu->ops || !viommu->ops->cache_invalidate) {
+                       rc = -EOPNOTSUPP;
+                       goto out_put_pt;
+               }
+               rc = viommu->ops->cache_invalidate(viommu, &data_array);
+       } else {
+               rc = -EINVAL;
+               goto out_put_pt;
+       }
 
-       rc = hwpt->domain->ops->cache_invalidate_user(hwpt->domain,
-                                                     &data_array);
        done_num = data_array.entry_num;
 
-       iommufd_put_object(ucmd->ictx, &hwpt->obj);
+out_put_pt:
+       iommufd_put_object(ucmd->ictx, pt_obj);
 out:
        cmd->entry_num = done_num;
        if (iommufd_ucmd_respond(ucmd, sizeof(*cmd)))
diff --git a/tools/testing/selftests/iommu/iommufd.c b/tools/testing/selftests/iommu/iommufd.c
index f3cb628753c9..8cb3e835ca97 100644
--- a/tools/testing/selftests/iommu/iommufd.c
+++ b/tools/testing/selftests/iommu/iommufd.c
@@ -367,9 +367,9 @@ TEST_F(iommufd_ioas, alloc_hwpt_nested)
                EXPECT_ERRNO(EBUSY,
                             _test_ioctl_destroy(self->fd, parent_hwpt_id));
 
-               /* hwpt_invalidate only supports a user-managed hwpt (nested) */
+               /* hwpt_invalidate does not support a parent hwpt */
                num_inv = 1;
-               test_err_hwpt_invalidate(ENOENT, parent_hwpt_id, inv_reqs,
+               test_err_hwpt_invalidate(EINVAL, parent_hwpt_id, inv_reqs,
                                         IOMMU_HWPT_INVALIDATE_DATA_SELFTEST,
                                         sizeof(*inv_reqs), &num_inv);
                assert(!num_inv);
-- 
2.43.0


Reply via email to