virtio-iommu supports a single IOMMU instance with multiple IDs.

It has a combined ACPI (via the VIOT table) and OF probe path; add
iommu_viot_get_single_iommu() to represent this.
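
As a sketch of what the combined path looks like from the driver's side
(the helper comes from earlier patches in this series; its signature is
assumed here from the call site in the diff below):

	viommu = iommu_viot_get_single_iommu(pinf, &viommu_ops,
					     struct viommu_dev, iommu);
	if (IS_ERR(viommu))
		return ERR_CAST(viommu);

It resolves the single instance referenced by either the VIOT table or
the OF iommus property and is expected to hand back the driver's
struct viommu_dev, container_of()-style, through its embedded
struct iommu_device member.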

It already has a per-device structure (struct viommu_endpoint); extend
it with the ids[] array and use iommu_fw_alloc_per_device_ids() to
populate it.
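
The per-device structure and its allocation then look roughly like this
(sketch; iommu_fw_alloc_per_device_ids() is assumed to size the
allocation from the number of firmware-provided IDs and to fill in
num_ids/ids[]):

	struct viommu_endpoint {
		...
		unsigned int	num_ids;
		u32		ids[] __counted_by(num_ids);
	};

	vdev = iommu_fw_alloc_per_device_ids(pinf, vdev);
	if (IS_ERR(vdev))
		return ERR_CAST(vdev);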

Convert the remaining functions from calling dev_iommu_fwspec_get() to
using the per-device data and remove all use of fwspec.
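
Call sites that used to walk the fwspec IDs, e.g.:

	for (i = 0; i < fwspec->num_ids; i++)
		req.endpoint = cpu_to_le32(fwspec->ids[i]);

now walk the per-device copy instead:

	for (i = 0; i < vdev->num_ids; i++)
		req.endpoint = cpu_to_le32(vdev->ids[i]);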

Signed-off-by: Jason Gunthorpe <j...@nvidia.com>
---
 drivers/iommu/virtio-iommu.c | 67 +++++++++++++-----------------------
 1 file changed, 23 insertions(+), 44 deletions(-)

diff --git a/drivers/iommu/virtio-iommu.c b/drivers/iommu/virtio-iommu.c
index b1a7b14a6c7a2f..767919bf848999 100644
--- a/drivers/iommu/virtio-iommu.c
+++ b/drivers/iommu/virtio-iommu.c
@@ -77,6 +77,8 @@ struct viommu_endpoint {
        struct viommu_dev               *viommu;
        struct viommu_domain            *vdomain;
        struct list_head                resv_regions;
+       unsigned int                    num_ids;
+       u32                             ids[] __counted_by(num_ids);
 };
 
 struct viommu_request {
@@ -510,19 +512,16 @@ static int viommu_add_resv_mem(struct viommu_endpoint *vdev,
        return 0;
 }
 
-static int viommu_probe_endpoint(struct viommu_dev *viommu, struct device *dev)
+static int viommu_probe_endpoint(struct viommu_endpoint *vdev)
 {
        int ret;
        u16 type, len;
        size_t cur = 0;
        size_t probe_len;
+       struct device *dev = vdev->dev;
        struct virtio_iommu_req_probe *probe;
        struct virtio_iommu_probe_property *prop;
-       struct iommu_fwspec *fwspec = dev_iommu_fwspec_get(dev);
-       struct viommu_endpoint *vdev = dev_iommu_priv_get(dev);
-
-       if (!fwspec->num_ids)
-               return -EINVAL;
+       struct viommu_dev *viommu = vdev->viommu;
 
        probe_len = sizeof(*probe) + viommu->probe_size +
                    sizeof(struct virtio_iommu_req_tail);
@@ -535,7 +534,7 @@ static int viommu_probe_endpoint(struct viommu_dev *viommu, struct device *dev)
         * For now, assume that properties of an endpoint that outputs multiple
         * IDs are consistent. Only probe the first one.
         */
-       probe->endpoint = cpu_to_le32(fwspec->ids[0]);
+       probe->endpoint = cpu_to_le32(vdev->ids[0]);
 
        ret = viommu_send_req_sync(viommu, probe, probe_len);
        if (ret)
@@ -721,7 +720,6 @@ static int viommu_attach_dev(struct iommu_domain *domain, struct device *dev)
        int i;
        int ret = 0;
        struct virtio_iommu_req_attach req;
-       struct iommu_fwspec *fwspec = dev_iommu_fwspec_get(dev);
        struct viommu_endpoint *vdev = dev_iommu_priv_get(dev);
        struct viommu_domain *vdomain = to_viommu_domain(domain);
 
@@ -763,8 +761,8 @@ static int viommu_attach_dev(struct iommu_domain *domain, struct device *dev)
        if (vdomain->bypass)
                req.flags |= cpu_to_le32(VIRTIO_IOMMU_ATTACH_F_BYPASS);
 
-       for (i = 0; i < fwspec->num_ids; i++) {
-               req.endpoint = cpu_to_le32(fwspec->ids[i]);
+       for (i = 0; i < vdev->num_ids; i++) {
+               req.endpoint = cpu_to_le32(vdev->ids[i]);
 
                ret = viommu_send_req_sync(vdomain->viommu, &req, sizeof(req));
                if (ret)
@@ -792,7 +790,6 @@ static void viommu_detach_dev(struct viommu_endpoint *vdev)
        int i;
        struct virtio_iommu_req_detach req;
        struct viommu_domain *vdomain = vdev->vdomain;
-       struct iommu_fwspec *fwspec = dev_iommu_fwspec_get(vdev->dev);
 
        if (!vdomain)
                return;
@@ -802,8 +799,8 @@ static void viommu_detach_dev(struct viommu_endpoint *vdev)
                .domain         = cpu_to_le32(vdomain->id),
        };
 
-       for (i = 0; i < fwspec->num_ids; i++) {
-               req.endpoint = cpu_to_le32(fwspec->ids[i]);
+       for (i = 0; i < vdev->num_ids; i++) {
+               req.endpoint = cpu_to_le32(vdev->ids[i]);
                WARN_ON(viommu_send_req_sync(vdev->viommu, &req, sizeof(req)));
        }
        vdomain->nr_endpoints--;
@@ -974,34 +971,21 @@ static void viommu_get_resv_regions(struct device *dev, struct list_head *head)
 static struct iommu_ops viommu_ops;
 static struct virtio_driver virtio_iommu_drv;
 
-static int viommu_match_node(struct device *dev, const void *data)
-{
-       return device_match_fwnode(dev->parent, data);
-}
-
-static struct viommu_dev *viommu_get_by_fwnode(struct fwnode_handle *fwnode)
-{
-       struct device *dev = driver_find_device(&virtio_iommu_drv.driver, NULL,
-                                               fwnode, viommu_match_node);
-       put_device(dev);
-
-       return dev ? dev_to_virtio(dev)->priv : NULL;
-}
-
-static struct iommu_device *viommu_probe_device(struct device *dev)
+static struct iommu_device *viommu_probe_device(struct iommu_probe_info *pinf)
 {
        int ret;
+       struct viommu_dev *viommu;
        struct viommu_endpoint *vdev;
-       struct viommu_dev *viommu = NULL;
-       struct iommu_fwspec *fwspec = dev_iommu_fwspec_get(dev);
+       struct device *dev = pinf->dev;
 
-       viommu = viommu_get_by_fwnode(fwspec->iommu_fwnode);
-       if (!viommu)
-               return ERR_PTR(-ENODEV);
+       viommu = iommu_viot_get_single_iommu(pinf, &viommu_ops,
+                                            struct viommu_dev, iommu);
+       if (IS_ERR(viommu))
+               return ERR_CAST(viommu);
 
-       vdev = kzalloc(sizeof(*vdev), GFP_KERNEL);
-       if (!vdev)
-               return ERR_PTR(-ENOMEM);
+       vdev = iommu_fw_alloc_per_device_ids(pinf, vdev);
+       if (IS_ERR(vdev))
+               return ERR_CAST(vdev);
 
        vdev->dev = dev;
        vdev->viommu = viommu;
@@ -1010,7 +994,7 @@ static struct iommu_device *viommu_probe_device(struct device *dev)
 
        if (viommu->probe_size) {
                /* Get additional information for this endpoint */
-               ret = viommu_probe_endpoint(viommu, dev);
+               ret = viommu_probe_endpoint(vdev);
                if (ret)
                        goto err_free_dev;
        }
@@ -1050,11 +1034,6 @@ static struct iommu_group *viommu_device_group(struct device *dev)
                return generic_device_group(dev);
 }
 
-static int viommu_of_xlate(struct device *dev, struct of_phandle_args *args)
-{
-       return iommu_fwspec_add_ids(dev, args->args, 1);
-}
-
 static bool viommu_capable(struct device *dev, enum iommu_cap cap)
 {
        switch (cap) {
@@ -1070,12 +1049,12 @@ static bool viommu_capable(struct device *dev, enum iommu_cap cap)
 static struct iommu_ops viommu_ops = {
        .capable                = viommu_capable,
        .domain_alloc           = viommu_domain_alloc,
-       .probe_device           = viommu_probe_device,
+       .probe_device_pinf      = viommu_probe_device,
        .probe_finalize         = viommu_probe_finalize,
        .release_device         = viommu_release_device,
        .device_group           = viommu_device_group,
        .get_resv_regions       = viommu_get_resv_regions,
-       .of_xlate               = viommu_of_xlate,
+       .of_xlate               = iommu_dummy_of_xlate,
        .owner                  = THIS_MODULE,
        .default_domain_ops = &(const struct iommu_domain_ops) {
                .attach_dev             = viommu_attach_dev,
-- 
2.42.0