From: Lu Baolu <baolu...@linux.intel.com>

Provide a high-level API that allows replacing one domain with another
for a specific pasid of a device. This is similar to
iommu_group_replace_domain() and is likewise expected to be used
only by IOMMUFD.

Signed-off-by: Lu Baolu <baolu...@linux.intel.com>
Signed-off-by: Yi Liu <yi.l....@intel.com>
---
 drivers/iommu/iommu-priv.h |  2 +
 drivers/iommu/iommu.c      | 82 +++++++++++++++++++++++++++++++-------
 2 files changed, 70 insertions(+), 14 deletions(-)

diff --git a/drivers/iommu/iommu-priv.h b/drivers/iommu/iommu-priv.h
index 2024a2313348..5c32637f6325 100644
--- a/drivers/iommu/iommu-priv.h
+++ b/drivers/iommu/iommu-priv.h
@@ -19,6 +19,8 @@ static inline const struct iommu_ops *dev_iommu_ops(struct 
device *dev)
 
 int iommu_group_replace_domain(struct iommu_group *group,
                               struct iommu_domain *new_domain);
+int iommu_replace_device_pasid(struct iommu_domain *domain,
+                              struct device *dev, ioasid_t pasid);
 
 int iommu_device_register_bus(struct iommu_device *iommu,
                              const struct iommu_ops *ops, struct bus_type *bus,
diff --git a/drivers/iommu/iommu.c b/drivers/iommu/iommu.c
index 9d573e971aff..ec213ebd5ecc 100644
--- a/drivers/iommu/iommu.c
+++ b/drivers/iommu/iommu.c
@@ -3430,6 +3430,27 @@ static void __iommu_remove_group_pasid(struct 
iommu_group *group,
        }
 }
 
+static int __iommu_group_attach_pasid(struct iommu_domain *domain,
+                                     struct iommu_group *group, ioasid_t pasid)
+{
+       void *curr;
+       int ret;
+
+       lockdep_assert_held(&group->mutex);
+       /* Claim the pasid slot only if it is currently empty. */
+       curr = xa_cmpxchg(&group->pasid_array, pasid, NULL, domain, GFP_KERNEL);
+       if (curr)
+               return xa_err(curr) ? : -EBUSY;
+       /* If the set fails, undo via remove and free the slot again. */
+       ret = __iommu_set_group_pasid(domain, group, pasid);
+       if (ret) {
+               __iommu_remove_group_pasid(group, pasid);
+               xa_erase(&group->pasid_array, pasid);
+       }
+
+       return ret;
+}
+
 /*
  * iommu_attach_device_pasid() - Attach a domain to pasid of device
  * @domain: the iommu domain.
@@ -3453,19 +3474,9 @@ int iommu_attach_device_pasid(struct iommu_domain 
*domain,
                return -ENODEV;
 
        mutex_lock(&group->mutex);
-       curr = xa_cmpxchg(&group->pasid_array, pasid, NULL, domain, GFP_KERNEL);
-       if (curr) {
-               ret = xa_err(curr) ? : -EBUSY;
-               goto out_unlock;
-       }
-
-       ret = __iommu_set_group_pasid(domain, group, pasid);
-       if (ret) {
-               __iommu_remove_group_pasid(group, pasid);
-               xa_erase(&group->pasid_array, pasid);
-       }
-out_unlock:
+       ret = __iommu_group_attach_pasid(domain, group, pasid);
        mutex_unlock(&group->mutex);
+
        return ret;
 }
 EXPORT_SYMBOL_GPL(iommu_attach_device_pasid);
@@ -3479,8 +3490,8 @@ EXPORT_SYMBOL_GPL(iommu_attach_device_pasid);
  * The @domain must have been attached to @pasid of the @dev with
  * iommu_attach_device_pasid().
  */
-void iommu_detach_device_pasid(struct iommu_domain *domain, struct device *dev,
-                              ioasid_t pasid)
+void iommu_detach_device_pasid(struct iommu_domain *domain,
+                              struct device *dev, ioasid_t pasid)
 {
        /* Caller must be a probed driver on dev */
        struct iommu_group *group = dev->iommu_group;
@@ -3492,6 +3503,49 @@ void iommu_detach_device_pasid(struct iommu_domain 
*domain, struct device *dev,
 }
 EXPORT_SYMBOL_GPL(iommu_detach_device_pasid);
 
+/**
+ * iommu_replace_device_pasid - replace the domain that a pasid is attached to
+ * @domain: new IOMMU domain to replace with
+ * @dev: the physical device
+ * @pasid: pasid that will be attached to the new domain
+ *
+ * This API allows the pasid to switch domains. Return 0 on success, or an
+ * error. On failure, re-attaching the old domain is attempted (unchecked).
+ * The caller should call iommu_detach_device_pasid() before freeing the old
+ * domain in order to avoid a use-after-free.
+ */
+int iommu_replace_device_pasid(struct iommu_domain *domain,
+                              struct device *dev, ioasid_t pasid)
+{
+       struct iommu_group *group = dev->iommu_group;
+       struct iommu_domain *old_domain;
+       int ret;
+
+       if (!domain)
+               return -EINVAL;
+
+       if (!group)
+               return -ENODEV;
+       /* Remove the current pasid entry; keep the old domain for rollback. */
+       mutex_lock(&group->mutex);
+       __iommu_remove_group_pasid(group, pasid);
+       old_domain = xa_erase(&group->pasid_array, pasid);
+       ret = __iommu_group_attach_pasid(domain, group, pasid);
+       if (ret)
+               goto err_rollback;
+       mutex_unlock(&group->mutex);
+
+       return 0;
+       /* NOTE(review): the rollback attach result below is ignored. */
+err_rollback:
+       if (old_domain)
+               __iommu_group_attach_pasid(old_domain, group, pasid);
+       mutex_unlock(&group->mutex);
+
+       return ret;
+}
+EXPORT_SYMBOL_NS_GPL(iommu_replace_device_pasid, IOMMUFD_INTERNAL);
+
 /*
  * iommu_get_domain_for_dev_pasid() - Retrieve domain for @pasid of @dev
  * @dev: the queried device
-- 
2.34.1


Reply via email to