Track the kvm pointer and its refcount in the viommu core. The kvm
pointer will be used later to support the TSM Bind feature, which tells
the secure firmware about the connection between a vPCI device and a
CoCo VM.

There is an existing need to reference the kvm pointer in viommu [1],
but in that series the kvm pointer is used and tracked by the platform
iommu drivers. In the Confidential Computing (CC) case, however, viommu
should manage a generic routine for TSM Bind, i.e. call
pci_tsm_bind(pdev, kvm, tdi_id). So it is better for the viommu core to
keep and track the kvm pointer.

[1] 
https://lore.kernel.org/all/20250319173202.78988-5-shameerali.kolothum.th...@huawei.com/

Signed-off-by: Lu Baolu <baolu...@linux.intel.com>
Signed-off-by: Xu Yilun <yilun...@linux.intel.com>
---
 drivers/iommu/iommufd/viommu.c | 62 ++++++++++++++++++++++++++++++++++
 include/linux/iommufd.h        |  3 ++
 2 files changed, 65 insertions(+)

diff --git a/drivers/iommu/iommufd/viommu.c b/drivers/iommu/iommufd/viommu.c
index 488905989b7c..2fcef3f8d1a5 100644
--- a/drivers/iommu/iommufd/viommu.c
+++ b/drivers/iommu/iommufd/viommu.c
@@ -1,8 +1,68 @@
 // SPDX-License-Identifier: GPL-2.0-only
 /* Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES
  */
+#if IS_ENABLED(CONFIG_KVM)
+#include <linux/kvm_host.h>
+#endif
+
 #include "iommufd_private.h"
 
+#if IS_ENABLED(CONFIG_KVM)
+static void viommu_get_kvm_safe(struct iommufd_viommu *viommu, struct kvm *kvm)
+{
+       bool (*get_safe)(struct kvm *kvm);
+       void (*put)(struct kvm *kvm);
+       bool got_ref;
+
+       if (!kvm)
+               return;
+
+       put = symbol_get(kvm_put_kvm);
+       if (WARN_ON(!put))
+               return;
+
+       get_safe = symbol_get(kvm_get_kvm_safe);
+       if (WARN_ON(!get_safe))
+               goto out_put_symbol;
+
+       got_ref = get_safe(kvm);
+       symbol_put(kvm_get_kvm_safe);
+       if (!got_ref)
+               goto out_put_symbol;
+
+       /* Keep the kvm_put_kvm symbol pinned until viommu_put_kvm(). */
+       viommu->put_kvm = put;
+       viommu->kvm = kvm;
+       return;
+out_put_symbol:
+       symbol_put(kvm_put_kvm);
+}
+
+static void viommu_put_kvm(struct iommufd_viommu *viommu)
+{
+       /* Drop the VM reference taken by viommu_get_kvm_safe(), if any. */
+       if (!viommu->kvm)
+               return;
+       if (!WARN_ON(!viommu->put_kvm)) {
+               viommu->put_kvm(viommu->kvm);
+               viommu->put_kvm = NULL;
+               /* Pairs with the symbol_get(kvm_put_kvm) on the get side. */
+               symbol_put(kvm_put_kvm);
+       }
+
+       /* Forget the VM even if the put callback was unexpectedly missing. */
+       viommu->kvm = NULL;
+}
+#else
+/*
+ * Without CONFIG_KVM there is no VM to reference, so the tracking helpers
+ * do nothing.
+ */
+static void viommu_get_kvm_safe(struct iommufd_viommu *viommu,
+                               struct kvm *kvm) { }
+static void viommu_put_kvm(struct iommufd_viommu *viommu) { }
+#endif
+
 void iommufd_viommu_destroy(struct iommufd_object *obj)
 {
        struct iommufd_viommu *viommu =
@@ -10,6 +70,7 @@ void iommufd_viommu_destroy(struct iommufd_object *obj)
 
        if (viommu->ops && viommu->ops->destroy)
                viommu->ops->destroy(viommu);
+       viommu_put_kvm(viommu);
        refcount_dec(&viommu->hwpt->common.obj.users);
        xa_destroy(&viommu->vdevs);
 }
@@ -68,6 +129,7 @@ int iommufd_viommu_alloc_ioctl(struct iommufd_ucmd *ucmd)
         * on its own.
         */
        viommu->iommu_dev = __iommu_get_iommu_dev(idev->dev);
+       viommu_get_kvm_safe(viommu, idev->kvm);
 
        cmd->out_viommu_id = viommu->obj.id;
        rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
diff --git a/include/linux/iommufd.h b/include/linux/iommufd.h
index 2b2d6095309c..2712421802b9 100644
--- a/include/linux/iommufd.h
+++ b/include/linux/iommufd.h
@@ -104,6 +104,9 @@ struct iommufd_viommu {
        struct rw_semaphore veventqs_rwsem;
 
        unsigned int type;
+
+       struct kvm *kvm;
+       void (*put_kvm)(struct kvm *kvm);
 };
 
 /**
-- 
2.25.1

Reply via email to