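The IOMMU aux-domain APIs work only when IOMMU_DEV_FEAT_AUX is enabled
for the device. Add this check to iommu_aux_attach_device() and
iommu_aux_get_pasid() so that they return -ENODEV instead of calling
into the driver when the feature is not enabled, to avoid misuse.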

Signed-off-by: Lu Baolu <baolu...@linux.intel.com>
---
 drivers/iommu/iommu.c | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/drivers/iommu/iommu.c b/drivers/iommu/iommu.c
index 1ed1e14a1f0c..e1fdd3531d65 100644
--- a/drivers/iommu/iommu.c
+++ b/drivers/iommu/iommu.c
@@ -2725,11 +2725,13 @@ EXPORT_SYMBOL_GPL(iommu_dev_feature_enabled);
  */
 int iommu_aux_attach_device(struct iommu_domain *domain, struct device *dev)
 {
-       int ret = -ENODEV;
+       int ret;
 
-       if (domain->ops->aux_attach_dev)
-               ret = domain->ops->aux_attach_dev(domain, dev);
+       if (!iommu_dev_feature_enabled(dev, IOMMU_DEV_FEAT_AUX) ||
+           !domain->ops->aux_attach_dev)
+               return -ENODEV;
 
+       ret = domain->ops->aux_attach_dev(domain, dev);
        if (!ret)
                trace_attach_device_to_domain(dev);
 
@@ -2748,12 +2750,12 @@ EXPORT_SYMBOL_GPL(iommu_aux_detach_device);
 
 int iommu_aux_get_pasid(struct iommu_domain *domain, struct device *dev)
 {
-       int ret = -ENODEV;
+       if (!iommu_dev_feature_enabled(dev, IOMMU_DEV_FEAT_AUX) ||
+           !domain->ops->aux_get_pasid)
+               return -ENODEV;
 
-       if (domain->ops->aux_get_pasid)
-               ret = domain->ops->aux_get_pasid(domain, dev);
+       return domain->ops->aux_get_pasid(domain, dev);
 
-       return ret;
 }
 EXPORT_SYMBOL_GPL(iommu_aux_get_pasid);
 
-- 
2.17.1

Reply via email to