The IOMMU core attaches each device to a default domain on probe(). Then,
every new "attach" operation has a fundamental meaning of two-fold:
 - detach from its currently attached (old) domain
 - attach to a given new domain

Modern IOMMU drivers following this pattern usually want to clean up the
things related to the old domain, so they call iommu_get_domain_for_dev()
to fetch the old domain.

Pass in the old domain pointer from the core to drivers, aligning with the
set_dev_pasid op that passes in already.

Ensure all low-level attach fcuntions in the core can forward the correct
old domain pointer. Thus, rework those functions as well.

Suggested-by: Jason Gunthorpe <j...@nvidia.com>
Signed-off-by: Nicolin Chen <nicol...@nvidia.com>
---
 include/linux/iommu.h                         |  3 +-
 drivers/iommu/amd/iommu.c                     | 11 ++++---
 drivers/iommu/apple-dart.c                    |  9 +++--
 .../arm/arm-smmu-v3/arm-smmu-v3-iommufd.c     |  5 +--
 drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.c   | 33 ++++++++++++-------
 drivers/iommu/arm/arm-smmu/arm-smmu.c         |  9 +++--
 drivers/iommu/arm/arm-smmu/qcom_iommu.c       | 11 ++++---
 drivers/iommu/exynos-iommu.c                  |  6 ++--
 drivers/iommu/fsl_pamu_domain.c               | 12 +++----
 drivers/iommu/intel/iommu.c                   | 10 ++++--
 drivers/iommu/intel/nested.c                  |  2 +-
 drivers/iommu/iommu.c                         | 26 +++++++++------
 drivers/iommu/iommufd/selftest.c              |  2 +-
 drivers/iommu/ipmmu-vmsa.c                    | 10 +++---
 drivers/iommu/msm_iommu.c                     |  8 ++---
 drivers/iommu/mtk_iommu.c                     |  8 ++---
 drivers/iommu/mtk_iommu_v1.c                  |  7 ++--
 drivers/iommu/omap-iommu.c                    | 12 +++----
 drivers/iommu/riscv/iommu.c                   |  9 +++--
 drivers/iommu/rockchip-iommu.c                | 20 ++++++++---
 drivers/iommu/s390-iommu.c                    |  9 +++--
 drivers/iommu/sprd-iommu.c                    |  3 +-
 drivers/iommu/sun50i-iommu.c                  |  8 +++--
 drivers/iommu/tegra-smmu.c                    | 10 +++---
 drivers/iommu/virtio-iommu.c                  |  6 ++--
 25 files changed, 152 insertions(+), 97 deletions(-)

diff --git a/include/linux/iommu.h b/include/linux/iommu.h
index c30d12e16473d..801b2bd9e8d49 100644
--- a/include/linux/iommu.h
+++ b/include/linux/iommu.h
@@ -751,7 +751,8 @@ struct iommu_ops {
  * @free: Release the domain after use.
  */
 struct iommu_domain_ops {
-       int (*attach_dev)(struct iommu_domain *domain, struct device *dev);
+       int (*attach_dev)(struct iommu_domain *domain, struct device *dev,
+                         struct iommu_domain *old);
        int (*set_dev_pasid)(struct iommu_domain *domain, struct device *dev,
                             ioasid_t pasid, struct iommu_domain *old);
 
diff --git a/drivers/iommu/amd/iommu.c b/drivers/iommu/amd/iommu.c
index eb348c63a8d09..8a18a3bfa5a2d 100644
--- a/drivers/iommu/amd/iommu.c
+++ b/drivers/iommu/amd/iommu.c
@@ -69,8 +69,8 @@ int amd_iommu_max_glx_val = -1;
  */
 DEFINE_IDA(pdom_ids);
 
-static int amd_iommu_attach_device(struct iommu_domain *dom,
-                                  struct device *dev);
+static int amd_iommu_attach_device(struct iommu_domain *dom, struct device 
*dev,
+                                  struct iommu_domain *old);
 
 static void set_dte_entry(struct amd_iommu *iommu,
                          struct iommu_dev_data *dev_data);
@@ -2634,7 +2634,8 @@ void amd_iommu_domain_free(struct iommu_domain *dom)
 }
 
 static int blocked_domain_attach_device(struct iommu_domain *domain,
-                                       struct device *dev)
+                                       struct device *dev,
+                                       struct iommu_domain *old)
 {
        struct iommu_dev_data *dev_data = dev_iommu_priv_get(dev);
 
@@ -2692,8 +2693,8 @@ static struct iommu_domain release_domain = {
        }
 };
 
-static int amd_iommu_attach_device(struct iommu_domain *dom,
-                                  struct device *dev)
+static int amd_iommu_attach_device(struct iommu_domain *dom, struct device 
*dev,
+                                  struct iommu_domain *old)
 {
        struct iommu_dev_data *dev_data = dev_iommu_priv_get(dev);
        struct protection_domain *domain = to_pdomain(dom);
diff --git a/drivers/iommu/apple-dart.c b/drivers/iommu/apple-dart.c
index 190f28d766151..88648e051e783 100644
--- a/drivers/iommu/apple-dart.c
+++ b/drivers/iommu/apple-dart.c
@@ -660,7 +660,8 @@ static int apple_dart_domain_add_streams(struct 
apple_dart_domain *domain,
 }
 
 static int apple_dart_attach_dev_paging(struct iommu_domain *domain,
-                                       struct device *dev)
+                                       struct device *dev,
+                                       struct iommu_domain *old)
 {
        int ret, i;
        struct apple_dart_stream_map *stream_map;
@@ -681,7 +682,8 @@ static int apple_dart_attach_dev_paging(struct iommu_domain 
*domain,
 }
 
 static int apple_dart_attach_dev_identity(struct iommu_domain *domain,
-                                         struct device *dev)
+                                         struct device *dev,
+                                         struct iommu_domain *old)
 {
        struct apple_dart_master_cfg *cfg = dev_iommu_priv_get(dev);
        struct apple_dart_stream_map *stream_map;
@@ -705,7 +707,8 @@ static struct iommu_domain apple_dart_identity_domain = {
 };
 
 static int apple_dart_attach_dev_blocked(struct iommu_domain *domain,
-                                        struct device *dev)
+                                        struct device *dev,
+                                        struct iommu_domain *old)
 {
        struct apple_dart_master_cfg *cfg = dev_iommu_priv_get(dev);
        struct apple_dart_stream_map *stream_map;
diff --git a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-iommufd.c 
b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-iommufd.c
index 8cd8929bbfdf8..313201a616991 100644
--- a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-iommufd.c
+++ b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-iommufd.c
@@ -138,14 +138,15 @@ void arm_smmu_master_clear_vmaster(struct arm_smmu_master 
*master)
 }
 
 static int arm_smmu_attach_dev_nested(struct iommu_domain *domain,
-                                     struct device *dev)
+                                     struct device *dev,
+                                     struct iommu_domain *old_domain)
 {
        struct arm_smmu_nested_domain *nested_domain =
                to_smmu_nested_domain(domain);
        struct arm_smmu_master *master = dev_iommu_priv_get(dev);
        struct arm_smmu_attach_state state = {
                .master = master,
-               .old_domain = iommu_get_domain_for_dev(dev),
+               .old_domain = old_domain,
                .ssid = IOMMU_NO_PASID,
        };
        struct arm_smmu_ste ste;
diff --git a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.c 
b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.c
index 1a21d1a2dd454..de02eeb524c15 100644
--- a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.c
+++ b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.c
@@ -3002,7 +3002,8 @@ void arm_smmu_attach_commit(struct arm_smmu_attach_state 
*state)
        arm_smmu_remove_master_domain(master, state->old_domain, state->ssid);
 }
 
-static int arm_smmu_attach_dev(struct iommu_domain *domain, struct device *dev)
+static int arm_smmu_attach_dev(struct iommu_domain *domain, struct device *dev,
+                              struct iommu_domain *old_domain)
 {
        int ret = 0;
        struct arm_smmu_ste target;
@@ -3010,7 +3011,7 @@ static int arm_smmu_attach_dev(struct iommu_domain 
*domain, struct device *dev)
        struct arm_smmu_device *smmu;
        struct arm_smmu_domain *smmu_domain = to_smmu_domain(domain);
        struct arm_smmu_attach_state state = {
-               .old_domain = iommu_get_domain_for_dev(dev),
+               .old_domain = old_domain,
                .ssid = IOMMU_NO_PASID,
        };
        struct arm_smmu_master *master;
@@ -3186,7 +3187,7 @@ static int arm_smmu_blocking_set_dev_pasid(struct 
iommu_domain *new_domain,
 
        /*
         * When the last user of the CD table goes away downgrade the STE back
-        * to a non-cd_table one.
+        * to a non-cd_table one, by re-attaching its sid_domain.
         */
        if (!arm_smmu_ssids_in_use(&master->cd_table)) {
                struct iommu_domain *sid_domain =
@@ -3194,12 +3195,14 @@ static int arm_smmu_blocking_set_dev_pasid(struct 
iommu_domain *new_domain,
 
                if (sid_domain->type == IOMMU_DOMAIN_IDENTITY ||
                    sid_domain->type == IOMMU_DOMAIN_BLOCKED)
-                       sid_domain->ops->attach_dev(sid_domain, dev);
+                       sid_domain->ops->attach_dev(sid_domain, dev,
+                                                   sid_domain);
        }
        return 0;
 }
 
 static void arm_smmu_attach_dev_ste(struct iommu_domain *domain,
+                                   struct iommu_domain *old_domain,
                                    struct device *dev,
                                    struct arm_smmu_ste *ste,
                                    unsigned int s1dss)
@@ -3207,7 +3210,7 @@ static void arm_smmu_attach_dev_ste(struct iommu_domain 
*domain,
        struct arm_smmu_master *master = dev_iommu_priv_get(dev);
        struct arm_smmu_attach_state state = {
                .master = master,
-               .old_domain = iommu_get_domain_for_dev(dev),
+               .old_domain = old_domain,
                .ssid = IOMMU_NO_PASID,
        };
 
@@ -3248,14 +3251,16 @@ static void arm_smmu_attach_dev_ste(struct iommu_domain 
*domain,
 }
 
 static int arm_smmu_attach_dev_identity(struct iommu_domain *domain,
-                                       struct device *dev)
+                                       struct device *dev,
+                                       struct iommu_domain *old_domain)
 {
        struct arm_smmu_ste ste;
        struct arm_smmu_master *master = dev_iommu_priv_get(dev);
 
        arm_smmu_master_clear_vmaster(master);
        arm_smmu_make_bypass_ste(master->smmu, &ste);
-       arm_smmu_attach_dev_ste(domain, dev, &ste, STRTAB_STE_1_S1DSS_BYPASS);
+       arm_smmu_attach_dev_ste(domain, old_domain, dev, &ste,
+                               STRTAB_STE_1_S1DSS_BYPASS);
        return 0;
 }
 
@@ -3269,14 +3274,15 @@ static struct iommu_domain arm_smmu_identity_domain = {
 };
 
 static int arm_smmu_attach_dev_blocked(struct iommu_domain *domain,
-                                       struct device *dev)
+                                      struct device *dev,
+                                      struct iommu_domain *old_domain)
 {
        struct arm_smmu_ste ste;
        struct arm_smmu_master *master = dev_iommu_priv_get(dev);
 
        arm_smmu_master_clear_vmaster(master);
        arm_smmu_make_abort_ste(&ste);
-       arm_smmu_attach_dev_ste(domain, dev, &ste,
+       arm_smmu_attach_dev_ste(domain, old_domain, dev, &ste,
                                STRTAB_STE_1_S1DSS_TERMINATE);
        return 0;
 }
@@ -3292,7 +3298,8 @@ static struct iommu_domain arm_smmu_blocked_domain = {
 };
 
 static int arm_smmu_attach_dev_release(struct iommu_domain *domain,
-                                      struct device *dev)
+                                      struct device *dev,
+                                      struct iommu_domain *old_domain)
 {
        struct arm_smmu_master *master = dev_iommu_priv_get(dev);
 
@@ -3300,9 +3307,11 @@ static int arm_smmu_attach_dev_release(struct 
iommu_domain *domain,
 
        /* Put the STE back to what arm_smmu_init_strtab() sets */
        if (dev->iommu->require_direct)
-               arm_smmu_attach_dev_identity(&arm_smmu_identity_domain, dev);
+               arm_smmu_attach_dev_identity(&arm_smmu_identity_domain, dev,
+                                            old_domain);
        else
-               arm_smmu_attach_dev_blocked(&arm_smmu_blocked_domain, dev);
+               arm_smmu_attach_dev_blocked(&arm_smmu_blocked_domain, dev,
+                                           old_domain);
 
        return 0;
 }
diff --git a/drivers/iommu/arm/arm-smmu/arm-smmu.c 
b/drivers/iommu/arm/arm-smmu/arm-smmu.c
index 4ced4b5bee4df..5e690cf85ec96 100644
--- a/drivers/iommu/arm/arm-smmu/arm-smmu.c
+++ b/drivers/iommu/arm/arm-smmu/arm-smmu.c
@@ -1165,7 +1165,8 @@ static void arm_smmu_master_install_s2crs(struct 
arm_smmu_master_cfg *cfg,
        }
 }
 
-static int arm_smmu_attach_dev(struct iommu_domain *domain, struct device *dev)
+static int arm_smmu_attach_dev(struct iommu_domain *domain, struct device *dev,
+                              struct iommu_domain *old)
 {
        struct arm_smmu_domain *smmu_domain = to_smmu_domain(domain);
        struct iommu_fwspec *fwspec = dev_iommu_fwspec_get(dev);
@@ -1234,7 +1235,8 @@ static int arm_smmu_attach_dev_type(struct device *dev,
 }
 
 static int arm_smmu_attach_dev_identity(struct iommu_domain *domain,
-                                       struct device *dev)
+                                       struct device *dev,
+                                       struct iommu_domain *old)
 {
        return arm_smmu_attach_dev_type(dev, S2CR_TYPE_BYPASS);
 }
@@ -1249,7 +1251,8 @@ static struct iommu_domain arm_smmu_identity_domain = {
 };
 
 static int arm_smmu_attach_dev_blocked(struct iommu_domain *domain,
-                                      struct device *dev)
+                                      struct device *dev,
+                                      struct iommu_domain *old)
 {
        return arm_smmu_attach_dev_type(dev, S2CR_TYPE_FAULT);
 }
diff --git a/drivers/iommu/arm/arm-smmu/qcom_iommu.c 
b/drivers/iommu/arm/arm-smmu/qcom_iommu.c
index c5be95e560317..9222a4a48bb33 100644
--- a/drivers/iommu/arm/arm-smmu/qcom_iommu.c
+++ b/drivers/iommu/arm/arm-smmu/qcom_iommu.c
@@ -359,7 +359,8 @@ static void qcom_iommu_domain_free(struct iommu_domain 
*domain)
        kfree(qcom_domain);
 }
 
-static int qcom_iommu_attach_dev(struct iommu_domain *domain, struct device 
*dev)
+static int qcom_iommu_attach_dev(struct iommu_domain *domain,
+                                struct device *dev, struct iommu_domain *old)
 {
        struct qcom_iommu_dev *qcom_iommu = dev_iommu_priv_get(dev);
        struct qcom_iommu_domain *qcom_domain = to_qcom_iommu_domain(domain);
@@ -388,18 +389,18 @@ static int qcom_iommu_attach_dev(struct iommu_domain 
*domain, struct device *dev
 }
 
 static int qcom_iommu_identity_attach(struct iommu_domain *identity_domain,
-                                     struct device *dev)
+                                     struct device *dev,
+                                     struct iommu_domain *old)
 {
-       struct iommu_domain *domain = iommu_get_domain_for_dev(dev);
        struct qcom_iommu_domain *qcom_domain;
        struct iommu_fwspec *fwspec = dev_iommu_fwspec_get(dev);
        struct qcom_iommu_dev *qcom_iommu = dev_iommu_priv_get(dev);
        unsigned int i;
 
-       if (domain == identity_domain || !domain)
+       if (old == identity_domain || !old)
                return 0;
 
-       qcom_domain = to_qcom_iommu_domain(domain);
+       qcom_domain = to_qcom_iommu_domain(old);
        if (WARN_ON(!qcom_domain->iommu))
                return -EINVAL;
 
diff --git a/drivers/iommu/exynos-iommu.c b/drivers/iommu/exynos-iommu.c
index b6edd178fe25e..b30d2bb87fa96 100644
--- a/drivers/iommu/exynos-iommu.c
+++ b/drivers/iommu/exynos-iommu.c
@@ -984,7 +984,8 @@ static void exynos_iommu_domain_free(struct iommu_domain 
*iommu_domain)
 }
 
 static int exynos_iommu_identity_attach(struct iommu_domain *identity_domain,
-                                       struct device *dev)
+                                       struct device *dev,
+                                       struct iommu_domain *old)
 {
        struct exynos_iommu_owner *owner = dev_iommu_priv_get(dev);
        struct exynos_iommu_domain *domain;
@@ -1035,7 +1036,8 @@ static struct iommu_domain exynos_identity_domain = {
 };
 
 static int exynos_iommu_attach_device(struct iommu_domain *iommu_domain,
-                                  struct device *dev)
+                                     struct device *dev,
+                                     struct iommu_domain *old)
 {
        struct exynos_iommu_domain *domain = to_exynos_domain(iommu_domain);
        struct exynos_iommu_owner *owner = dev_iommu_priv_get(dev);
diff --git a/drivers/iommu/fsl_pamu_domain.c b/drivers/iommu/fsl_pamu_domain.c
index 5f08523f97cb9..9664ef9840d2c 100644
--- a/drivers/iommu/fsl_pamu_domain.c
+++ b/drivers/iommu/fsl_pamu_domain.c
@@ -238,7 +238,7 @@ static int update_domain_stash(struct fsl_dma_domain 
*dma_domain, u32 val)
 }
 
 static int fsl_pamu_attach_device(struct iommu_domain *domain,
-                                 struct device *dev)
+                                 struct device *dev, struct iommu_domain *old)
 {
        struct fsl_dma_domain *dma_domain = to_fsl_dma_domain(domain);
        unsigned long flags;
@@ -298,9 +298,9 @@ static int fsl_pamu_attach_device(struct iommu_domain 
*domain,
  * switches to what looks like BLOCKING.
  */
 static int fsl_pamu_platform_attach(struct iommu_domain *platform_domain,
-                                   struct device *dev)
+                                   struct device *dev,
+                                   struct iommu_domain *old)
 {
-       struct iommu_domain *domain = iommu_get_domain_for_dev(dev);
        struct fsl_dma_domain *dma_domain;
        const u32 *prop;
        int len;
@@ -311,11 +311,11 @@ static int fsl_pamu_platform_attach(struct iommu_domain 
*platform_domain,
         * Hack to keep things working as they always have, only leaving an
         * UNMANAGED domain makes it BLOCKING.
         */
-       if (domain == platform_domain || !domain ||
-           domain->type != IOMMU_DOMAIN_UNMANAGED)
+       if (old == platform_domain || !old ||
+           old->type != IOMMU_DOMAIN_UNMANAGED)
                return 0;
 
-       dma_domain = to_fsl_dma_domain(domain);
+       dma_domain = to_fsl_dma_domain(old);
 
        /*
         * Use LIODN of the PCI controller while detaching a
diff --git a/drivers/iommu/intel/iommu.c b/drivers/iommu/intel/iommu.c
index 9c3ab9d9f69a3..e9fbe9f6cc6cd 100644
--- a/drivers/iommu/intel/iommu.c
+++ b/drivers/iommu/intel/iommu.c
@@ -3225,7 +3225,8 @@ void device_block_translation(struct device *dev)
 }
 
 static int blocking_domain_attach_dev(struct iommu_domain *domain,
-                                     struct device *dev)
+                                     struct device *dev,
+                                     struct iommu_domain *old)
 {
        struct device_domain_info *info = dev_iommu_priv_get(dev);
 
@@ -3532,7 +3533,8 @@ int paging_domain_compatible(struct iommu_domain *domain, 
struct device *dev)
 }
 
 static int intel_iommu_attach_device(struct iommu_domain *domain,
-                                    struct device *dev)
+                                    struct device *dev,
+                                    struct iommu_domain *old)
 {
        int ret;
 
@@ -4396,7 +4398,9 @@ static int device_setup_pass_through(struct device *dev)
                                      context_setup_pass_through_cb, dev);
 }
 
-static int identity_domain_attach_dev(struct iommu_domain *domain, struct 
device *dev)
+static int identity_domain_attach_dev(struct iommu_domain *domain,
+                                     struct device *dev,
+                                     struct iommu_domain *old)
 {
        struct device_domain_info *info = dev_iommu_priv_get(dev);
        struct intel_iommu *iommu = info->iommu;
diff --git a/drivers/iommu/intel/nested.c b/drivers/iommu/intel/nested.c
index 1b6ad9c900a5a..760d7aa2ade84 100644
--- a/drivers/iommu/intel/nested.c
+++ b/drivers/iommu/intel/nested.c
@@ -19,7 +19,7 @@
 #include "pasid.h"
 
 static int intel_nested_attach_dev(struct iommu_domain *domain,
-                                  struct device *dev)
+                                  struct device *dev, struct iommu_domain *old)
 {
        struct device_domain_info *info = dev_iommu_priv_get(dev);
        struct dmar_domain *dmar_domain = to_dmar_domain(domain);
diff --git a/drivers/iommu/iommu.c b/drivers/iommu/iommu.c
index e6a66dacce1b8..ef3fd7bd1b553 100644
--- a/drivers/iommu/iommu.c
+++ b/drivers/iommu/iommu.c
@@ -100,7 +100,7 @@ static int iommu_bus_notifier(struct notifier_block *nb,
                              unsigned long action, void *data);
 static void iommu_release_device(struct device *dev);
 static int __iommu_attach_device(struct iommu_domain *domain,
-                                struct device *dev);
+                                struct device *dev, struct iommu_domain *old);
 static int __iommu_attach_group(struct iommu_domain *domain,
                                struct iommu_group *group);
 static struct iommu_domain *__iommu_paging_domain_alloc_flags(struct device 
*dev,
@@ -114,6 +114,7 @@ enum {
 static int __iommu_device_set_domain(struct iommu_group *group,
                                     struct group_device *gdev,
                                     struct iommu_domain *new_domain,
+                                    struct iommu_domain *old_domain,
                                     unsigned int flags);
 static int __iommu_group_set_domain_internal(struct iommu_group *group,
                                             struct iommu_domain *new_domain,
@@ -517,7 +518,8 @@ static void iommu_deinit_device(struct device *dev)
         * should still avoid touching any hardware configuration either.
         */
        if (!dev->iommu->attach_deferred && ops->release_domain)
-               ops->release_domain->ops->attach_dev(ops->release_domain, dev);
+               ops->release_domain->ops->attach_dev(ops->release_domain, dev,
+                                                    group->domain);
 
        if (ops->release_device)
                ops->release_device(dev);
@@ -602,7 +604,8 @@ static int __iommu_probe_device(struct device *dev, struct 
list_head *group_list
        if (group->default_domain)
                iommu_create_device_direct_mappings(group->default_domain, dev);
        if (group->domain) {
-               ret = __iommu_device_set_domain(group, gdev, group->domain, 0);
+               ret = __iommu_device_set_domain(group, gdev, group->domain,
+                                               NULL, 0);
                if (ret)
                        goto err_remove_gdev;
        } else if (!group->default_domain && !group_list) {
@@ -2089,14 +2092,14 @@ static void __iommu_group_set_core_domain(struct 
iommu_group *group)
 }
 
 static int __iommu_attach_device(struct iommu_domain *domain,
-                                struct device *dev)
+                                struct device *dev, struct iommu_domain *old)
 {
        int ret;
 
        if (unlikely(domain->ops->attach_dev == NULL))
                return -ENODEV;
 
-       ret = domain->ops->attach_dev(domain, dev);
+       ret = domain->ops->attach_dev(domain, dev, old);
        if (ret)
                return ret;
        dev->iommu->attach_deferred = 0;
@@ -2154,7 +2157,7 @@ int iommu_deferred_attach(struct device *dev, struct 
iommu_domain *domain)
 
        guard(mutex)(&dev->iommu_group->mutex);
 
-       return __iommu_attach_device(domain, dev);
+       return __iommu_attach_device(domain, dev, NULL);
 }
 
 void iommu_detach_device(struct iommu_domain *domain, struct device *dev)
@@ -2265,6 +2268,7 @@ EXPORT_SYMBOL_GPL(iommu_attach_group);
 static int __iommu_device_set_domain(struct iommu_group *group,
                                     struct group_device *gdev,
                                     struct iommu_domain *new_domain,
+                                    struct iommu_domain *old_domain,
                                     unsigned int flags)
 {
        struct device *dev = gdev->dev;
@@ -2291,7 +2295,7 @@ static int __iommu_device_set_domain(struct iommu_group 
*group,
                dev->iommu->attach_deferred = 0;
        }
 
-       ret = __iommu_attach_device(new_domain, dev);
+       ret = __iommu_attach_device(new_domain, dev, old_domain);
        if (ret) {
                /*
                 * If we have a blocking domain then try to attach that in hopes
@@ -2301,7 +2305,8 @@ static int __iommu_device_set_domain(struct iommu_group 
*group,
                if ((flags & IOMMU_SET_DOMAIN_MUST_SUCCEED) &&
                    group->blocking_domain &&
                    group->blocking_domain != new_domain)
-                       __iommu_attach_device(group->blocking_domain, dev);
+                       __iommu_attach_device(group->blocking_domain, dev,
+                                             old_domain);
                return ret;
        }
        return 0;
@@ -2347,7 +2352,8 @@ static int __iommu_group_set_domain_internal(struct 
iommu_group *group,
         */
        result = 0;
        for_each_group_device(group, gdev) {
-               ret = __iommu_device_set_domain(group, gdev, new_domain, flags);
+               ret = __iommu_device_set_domain(group, gdev, new_domain,
+                                               group->domain, flags);
                if (ret) {
                        result = ret;
                        /*
@@ -2379,7 +2385,7 @@ static int __iommu_group_set_domain_internal(struct 
iommu_group *group,
                 */
                if (group->domain)
                        WARN_ON(__iommu_device_set_domain(
-                               group, gdev, group->domain,
+                               group, gdev, group->domain, new_domain,
                                IOMMU_SET_DOMAIN_MUST_SUCCEED));
                if (gdev == last_gdev)
                        break;
diff --git a/drivers/iommu/iommufd/selftest.c b/drivers/iommu/iommufd/selftest.c
index 61686603c7693..fa4359bb15e84 100644
--- a/drivers/iommu/iommufd/selftest.c
+++ b/drivers/iommu/iommufd/selftest.c
@@ -216,7 +216,7 @@ static inline struct selftest_obj *to_selftest_obj(struct 
iommufd_object *obj)
 }
 
 static int mock_domain_nop_attach(struct iommu_domain *domain,
-                                 struct device *dev)
+                                 struct device *dev, struct iommu_domain *old)
 {
        struct mock_dev *mdev = to_mock_dev(dev);
        struct mock_viommu *new_viommu = NULL;
diff --git a/drivers/iommu/ipmmu-vmsa.c b/drivers/iommu/ipmmu-vmsa.c
index ffa892f657140..6667ecc331f01 100644
--- a/drivers/iommu/ipmmu-vmsa.c
+++ b/drivers/iommu/ipmmu-vmsa.c
@@ -590,7 +590,7 @@ static void ipmmu_domain_free(struct iommu_domain 
*io_domain)
 }
 
 static int ipmmu_attach_device(struct iommu_domain *io_domain,
-                              struct device *dev)
+                              struct device *dev, struct iommu_domain *old)
 {
        struct iommu_fwspec *fwspec = dev_iommu_fwspec_get(dev);
        struct ipmmu_vmsa_device *mmu = to_ipmmu(dev);
@@ -637,17 +637,17 @@ static int ipmmu_attach_device(struct iommu_domain 
*io_domain,
 }
 
 static int ipmmu_iommu_identity_attach(struct iommu_domain *identity_domain,
-                                      struct device *dev)
+                                      struct device *dev,
+                                      struct iommu_domain *old)
 {
-       struct iommu_domain *io_domain = iommu_get_domain_for_dev(dev);
        struct iommu_fwspec *fwspec = dev_iommu_fwspec_get(dev);
        struct ipmmu_vmsa_domain *domain;
        unsigned int i;
 
-       if (io_domain == identity_domain || !io_domain)
+       if (old == identity_domain || !old)
                return 0;
 
-       domain = to_vmsa_domain(io_domain);
+       domain = to_vmsa_domain(old);
        for (i = 0; i < fwspec->num_ids; ++i)
                ipmmu_utlb_disable(domain, fwspec->ids[i]);
 
diff --git a/drivers/iommu/msm_iommu.c b/drivers/iommu/msm_iommu.c
index 43a61ba021a51..2629fbd0606d3 100644
--- a/drivers/iommu/msm_iommu.c
+++ b/drivers/iommu/msm_iommu.c
@@ -441,19 +441,19 @@ static int msm_iommu_attach_dev(struct iommu_domain 
*domain, struct device *dev)
 }
 
 static int msm_iommu_identity_attach(struct iommu_domain *identity_domain,
-                                    struct device *dev)
+                                    struct device *dev,
+                                    struct iommu_domain *old)
 {
-       struct iommu_domain *domain = iommu_get_domain_for_dev(dev);
        struct msm_priv *priv;
        unsigned long flags;
        struct msm_iommu_dev *iommu;
        struct msm_iommu_ctx_dev *master;
        int ret = 0;
 
-       if (domain == identity_domain || !domain)
+       if (old == identity_domain || !old)
                return 0;
 
-       priv = to_msm_priv(domain);
+       priv = to_msm_priv(old);
        free_io_pgtable_ops(priv->iop);
 
        spin_lock_irqsave(&msm_iommu_lock, flags);
diff --git a/drivers/iommu/mtk_iommu.c b/drivers/iommu/mtk_iommu.c
index 0e0285348d2b8..9747ef1644138 100644
--- a/drivers/iommu/mtk_iommu.c
+++ b/drivers/iommu/mtk_iommu.c
@@ -705,7 +705,7 @@ static void mtk_iommu_domain_free(struct iommu_domain 
*domain)
 }
 
 static int mtk_iommu_attach_device(struct iommu_domain *domain,
-                                  struct device *dev)
+                                  struct device *dev, struct iommu_domain *old)
 {
        struct mtk_iommu_data *data = dev_iommu_priv_get(dev), *frstdata;
        struct mtk_iommu_domain *dom = to_mtk_domain(domain);
@@ -773,12 +773,12 @@ static int mtk_iommu_attach_device(struct iommu_domain 
*domain,
 }
 
 static int mtk_iommu_identity_attach(struct iommu_domain *identity_domain,
-                                    struct device *dev)
+                                    struct device *dev,
+                                    struct iommu_domain *old)
 {
-       struct iommu_domain *domain = iommu_get_domain_for_dev(dev);
        struct mtk_iommu_data *data = dev_iommu_priv_get(dev);
 
-       if (domain == identity_domain || !domain)
+       if (old == identity_domain || !old)
                return 0;
 
        mtk_iommu_config(data, dev, false, 0);
diff --git a/drivers/iommu/mtk_iommu_v1.c b/drivers/iommu/mtk_iommu_v1.c
index 10cc0b1197e80..3b45650263ac3 100644
--- a/drivers/iommu/mtk_iommu_v1.c
+++ b/drivers/iommu/mtk_iommu_v1.c
@@ -303,7 +303,9 @@ static void mtk_iommu_v1_domain_free(struct iommu_domain 
*domain)
        kfree(to_mtk_domain(domain));
 }
 
-static int mtk_iommu_v1_attach_device(struct iommu_domain *domain, struct 
device *dev)
+static int mtk_iommu_v1_attach_device(struct iommu_domain *domain,
+                                     struct device *dev,
+                                     struct iommu_domain *old)
 {
        struct mtk_iommu_v1_data *data = dev_iommu_priv_get(dev);
        struct mtk_iommu_v1_domain *dom = to_mtk_domain(domain);
@@ -329,7 +331,8 @@ static int mtk_iommu_v1_attach_device(struct iommu_domain 
*domain, struct device
 }
 
 static int mtk_iommu_v1_identity_attach(struct iommu_domain *identity_domain,
-                                       struct device *dev)
+                                       struct device *dev,
+                                       struct iommu_domain *old)
 {
        struct mtk_iommu_v1_data *data = dev_iommu_priv_get(dev);
 
diff --git a/drivers/iommu/omap-iommu.c b/drivers/iommu/omap-iommu.c
index 6fb93927bdb98..a38f69debac09 100644
--- a/drivers/iommu/omap-iommu.c
+++ b/drivers/iommu/omap-iommu.c
@@ -1431,8 +1431,8 @@ static void omap_iommu_detach_fini(struct 
omap_iommu_domain *odomain)
        odomain->iommus = NULL;
 }
 
-static int
-omap_iommu_attach_dev(struct iommu_domain *domain, struct device *dev)
+static int omap_iommu_attach_dev(struct iommu_domain *domain,
+                                struct device *dev, struct iommu_domain *old)
 {
        struct omap_iommu_arch_data *arch_data = dev_iommu_priv_get(dev);
        struct omap_iommu_domain *omap_domain = to_omap_domain(domain);
@@ -1536,15 +1536,15 @@ static void _omap_iommu_detach_dev(struct 
omap_iommu_domain *omap_domain,
 }
 
 static int omap_iommu_identity_attach(struct iommu_domain *identity_domain,
-                                     struct device *dev)
+                                     struct device *dev,
+                                     struct iommu_domain *old)
 {
-       struct iommu_domain *domain = iommu_get_domain_for_dev(dev);
        struct omap_iommu_domain *omap_domain;
 
-       if (domain == identity_domain || !domain)
+       if (old == identity_domain || !old)
                return 0;
 
-       omap_domain = to_omap_domain(domain);
+       omap_domain = to_omap_domain(old);
        spin_lock(&omap_domain->lock);
        _omap_iommu_detach_dev(omap_domain, dev);
        spin_unlock(&omap_domain->lock);
diff --git a/drivers/iommu/riscv/iommu.c b/drivers/iommu/riscv/iommu.c
index 2d0d31ba28860..9c58413ad641e 100644
--- a/drivers/iommu/riscv/iommu.c
+++ b/drivers/iommu/riscv/iommu.c
@@ -1319,7 +1319,8 @@ static bool riscv_iommu_pt_supported(struct 
riscv_iommu_device *iommu, int pgd_m
 }
 
 static int riscv_iommu_attach_paging_domain(struct iommu_domain *iommu_domain,
-                                           struct device *dev)
+                                           struct device *dev,
+                                           struct iommu_domain *old)
 {
        struct riscv_iommu_domain *domain = iommu_domain_to_riscv(iommu_domain);
        struct riscv_iommu_device *iommu = dev_to_iommu(dev);
@@ -1424,7 +1425,8 @@ static struct iommu_domain 
*riscv_iommu_alloc_paging_domain(struct device *dev)
 }
 
 static int riscv_iommu_attach_blocking_domain(struct iommu_domain 
*iommu_domain,
-                                             struct device *dev)
+                                             struct device *dev,
+                                             struct iommu_domain *old)
 {
        struct riscv_iommu_device *iommu = dev_to_iommu(dev);
        struct riscv_iommu_info *info = dev_iommu_priv_get(dev);
@@ -1445,7 +1447,8 @@ static struct iommu_domain riscv_iommu_blocking_domain = {
 };
 
 static int riscv_iommu_attach_identity_domain(struct iommu_domain 
*iommu_domain,
-                                             struct device *dev)
+                                             struct device *dev,
+                                             struct iommu_domain *old)
 {
        struct riscv_iommu_device *iommu = dev_to_iommu(dev);
        struct riscv_iommu_info *info = dev_iommu_priv_get(dev);
diff --git a/drivers/iommu/rockchip-iommu.c b/drivers/iommu/rockchip-iommu.c
index 0861dd469bd86..85f3667e797c3 100644
--- a/drivers/iommu/rockchip-iommu.c
+++ b/drivers/iommu/rockchip-iommu.c
@@ -960,7 +960,8 @@ static int rk_iommu_enable(struct rk_iommu *iommu)
 }
 
 static int rk_iommu_identity_attach(struct iommu_domain *identity_domain,
-                                   struct device *dev)
+                                   struct device *dev,
+                                   struct iommu_domain *old)
 {
        struct rk_iommu *iommu;
        struct rk_iommu_domain *rk_domain;
@@ -1005,7 +1006,7 @@ static struct iommu_domain rk_identity_domain = {
 };
 
 static int rk_iommu_attach_device(struct iommu_domain *domain,
-               struct device *dev)
+                                 struct device *dev, struct iommu_domain *old)
 {
        struct rk_iommu *iommu;
        struct rk_iommu_domain *rk_domain = to_rk_domain(domain);
@@ -1026,7 +1027,7 @@ static int rk_iommu_attach_device(struct iommu_domain 
*domain,
        if (iommu->domain == domain)
                return 0;
 
-       ret = rk_iommu_identity_attach(&rk_identity_domain, dev);
+       ret = rk_iommu_identity_attach(&rk_identity_domain, dev, old);
        if (ret)
                return ret;
 
@@ -1041,8 +1042,17 @@ static int rk_iommu_attach_device(struct iommu_domain 
*domain,
                return 0;
 
        ret = rk_iommu_enable(iommu);
-       if (ret)
-               WARN_ON(rk_iommu_identity_attach(&rk_identity_domain, dev));
+       if (ret) {
+               /*
+                * Note rk_iommu_identity_attach() might fail before physically
+                * attaching the dev to iommu->domain, in which case the actual
+                * old domain for this revert should be rk_identity_domain v.s.
+                * iommu->domain. Since rk_iommu_identity_attach() does not care
+                * about the old domain argument for now, this is not a problem.
+                */
+               WARN_ON(rk_iommu_identity_attach(&rk_identity_domain, dev,
+                                                iommu->domain));
+       }
 
        pm_runtime_put(iommu->dev);
 
diff --git a/drivers/iommu/s390-iommu.c b/drivers/iommu/s390-iommu.c
index 9c80d61deb2c0..f2f58bb21720b 100644
--- a/drivers/iommu/s390-iommu.c
+++ b/drivers/iommu/s390-iommu.c
@@ -653,7 +653,8 @@ int zpci_iommu_register_ioat(struct zpci_dev *zdev, u8 
*status)
 }
 
 static int blocking_domain_attach_device(struct iommu_domain *domain,
-                                        struct device *dev)
+                                        struct device *dev,
+                                        struct iommu_domain *old)
 {
        struct zpci_dev *zdev = to_zpci_dev(dev);
        struct s390_domain *s390_domain;
@@ -677,7 +678,8 @@ static int blocking_domain_attach_device(struct 
iommu_domain *domain,
 }
 
 static int s390_iommu_attach_device(struct iommu_domain *domain,
-                                   struct device *dev)
+                                   struct device *dev,
+                                   struct iommu_domain *old)
 {
        struct s390_domain *s390_domain = to_s390_domain(domain);
        struct zpci_dev *zdev = to_zpci_dev(dev);
@@ -1113,7 +1115,8 @@ static int __init s390_iommu_init(void)
 subsys_initcall(s390_iommu_init);
 
 static int s390_attach_dev_identity(struct iommu_domain *domain,
-                                   struct device *dev)
+                                   struct device *dev,
+                                   struct iommu_domain *old)
 {
        struct zpci_dev *zdev = to_zpci_dev(dev);
        u8 status;
diff --git a/drivers/iommu/sprd-iommu.c b/drivers/iommu/sprd-iommu.c
index c7ca1d8a0b153..555d4505c747a 100644
--- a/drivers/iommu/sprd-iommu.c
+++ b/drivers/iommu/sprd-iommu.c
@@ -247,7 +247,8 @@ static void sprd_iommu_domain_free(struct iommu_domain 
*domain)
 }
 
 static int sprd_iommu_attach_device(struct iommu_domain *domain,
-                                   struct device *dev)
+                                   struct device *dev,
+                                   struct iommu_domain *old)
 {
        struct sprd_iommu_device *sdev = dev_iommu_priv_get(dev);
        struct sprd_iommu_domain *dom = to_sprd_domain(domain);
diff --git a/drivers/iommu/sun50i-iommu.c b/drivers/iommu/sun50i-iommu.c
index de10b569d9a94..d3b190be18b5a 100644
--- a/drivers/iommu/sun50i-iommu.c
+++ b/drivers/iommu/sun50i-iommu.c
@@ -771,7 +771,8 @@ static void sun50i_iommu_detach_domain(struct sun50i_iommu 
*iommu,
 }
 
 static int sun50i_iommu_identity_attach(struct iommu_domain *identity_domain,
-                                       struct device *dev)
+                                       struct device *dev,
+                                       struct iommu_domain *old)
 {
        struct sun50i_iommu *iommu = dev_iommu_priv_get(dev);
        struct sun50i_iommu_domain *sun50i_domain;
@@ -797,7 +798,8 @@ static struct iommu_domain sun50i_iommu_identity_domain = {
 };
 
 static int sun50i_iommu_attach_device(struct iommu_domain *domain,
-                                     struct device *dev)
+                                     struct device *dev,
+                                     struct iommu_domain *old)
 {
        struct sun50i_iommu_domain *sun50i_domain = to_sun50i_domain(domain);
        struct sun50i_iommu *iommu;
@@ -813,7 +815,7 @@ static int sun50i_iommu_attach_device(struct iommu_domain 
*domain,
        if (iommu->domain == domain)
                return 0;
 
-       sun50i_iommu_identity_attach(&sun50i_iommu_identity_domain, dev);
+       sun50i_iommu_identity_attach(&sun50i_iommu_identity_domain, dev, old);
 
        sun50i_iommu_attach_domain(iommu, sun50i_domain);
 
diff --git a/drivers/iommu/tegra-smmu.c b/drivers/iommu/tegra-smmu.c
index 36cdd5fbab077..336e0a3ff41fb 100644
--- a/drivers/iommu/tegra-smmu.c
+++ b/drivers/iommu/tegra-smmu.c
@@ -490,7 +490,7 @@ static void tegra_smmu_as_unprepare(struct tegra_smmu *smmu,
 }
 
 static int tegra_smmu_attach_dev(struct iommu_domain *domain,
-                                struct device *dev)
+                                struct device *dev, struct iommu_domain *old)
 {
        struct iommu_fwspec *fwspec = dev_iommu_fwspec_get(dev);
        struct tegra_smmu *smmu = dev_iommu_priv_get(dev);
@@ -524,9 +524,9 @@ static int tegra_smmu_attach_dev(struct iommu_domain 
*domain,
 }
 
 static int tegra_smmu_identity_attach(struct iommu_domain *identity_domain,
-                                     struct device *dev)
+                                     struct device *dev,
+                                     struct iommu_domain *old)
 {
-       struct iommu_domain *domain = iommu_get_domain_for_dev(dev);
        struct iommu_fwspec *fwspec = dev_iommu_fwspec_get(dev);
        struct tegra_smmu_as *as;
        struct tegra_smmu *smmu;
@@ -535,10 +535,10 @@ static int tegra_smmu_identity_attach(struct iommu_domain 
*identity_domain,
        if (!fwspec)
                return -ENODEV;
 
-       if (domain == identity_domain || !domain)
+       if (old == identity_domain || !old)
                return 0;
 
-       as = to_smmu_as(domain);
+       as = to_smmu_as(old);
        smmu = as->smmu;
        for (index = 0; index < fwspec->num_ids; index++) {
                tegra_smmu_disable(smmu, fwspec->ids[index], as->id);
diff --git a/drivers/iommu/virtio-iommu.c b/drivers/iommu/virtio-iommu.c
index 532db1de201ba..bd1af8a77005f 100644
--- a/drivers/iommu/virtio-iommu.c
+++ b/drivers/iommu/virtio-iommu.c
@@ -730,7 +730,8 @@ static struct iommu_domain 
*viommu_domain_alloc_identity(struct device *dev)
        return domain;
 }
 
-static int viommu_attach_dev(struct iommu_domain *domain, struct device *dev)
+static int viommu_attach_dev(struct iommu_domain *domain, struct device *dev,
+                            struct iommu_domain *old)
 {
        int ret = 0;
        struct virtio_iommu_req_attach req;
@@ -781,7 +782,8 @@ static int viommu_attach_dev(struct iommu_domain *domain, 
struct device *dev)
 }
 
 static int viommu_attach_identity_domain(struct iommu_domain *domain,
-                                        struct device *dev)
+                                        struct device *dev,
+                                        struct iommu_domain *old)
 {
        int ret = 0;
        struct virtio_iommu_req_attach req;
-- 
2.43.0


Reply via email to