Track the kvm pointer and its refcount in the viommu core. The kvm pointer will be used later to support the TSM Bind feature, which tells the secure firmware about the connection between a vPCI device and a CoCo VM.
There is an existing need to reference the kvm pointer in viommu [1], but in that series the kvm pointer is used and tracked in the platform IOMMU drivers. In the Confidential Computing (CC) case, however, viommu should manage a generic routine for TSM Bind, i.e. call pci_tsm_bind(pdev, kvm, tdi_id). So it is better for the viommu core, rather than each platform IOMMU driver, to keep and track the kvm pointer. [1] https://lore.kernel.org/all/20250319173202.78988-5-shameerali.kolothum.th...@huawei.com/ Signed-off-by: Lu Baolu <baolu...@linux.intel.com> Signed-off-by: Xu Yilun <yilun...@linux.intel.com> --- drivers/iommu/iommufd/viommu.c | 62 ++++++++++++++++++++++++++++++++++ include/linux/iommufd.h | 3 ++ 2 files changed, 65 insertions(+) diff --git a/drivers/iommu/iommufd/viommu.c b/drivers/iommu/iommufd/viommu.c index 488905989b7c..2fcef3f8d1a5 100644 --- a/drivers/iommu/iommufd/viommu.c +++ b/drivers/iommu/iommufd/viommu.c @@ -1,8 +1,68 @@ // SPDX-License-Identifier: GPL-2.0-only /* Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES */ +#if IS_ENABLED(CONFIG_KVM) +#include <linux/kvm_host.h> +#endif + #include "iommufd_private.h" +#if IS_ENABLED(CONFIG_KVM) +static void viommu_get_kvm_safe(struct iommufd_viommu *viommu, struct kvm *kvm) +{ + void (*pfn)(struct kvm *kvm); + bool (*fn)(struct kvm *kvm); + bool ret; + + if (!kvm) + return; + + pfn = symbol_get(kvm_put_kvm); + if (WARN_ON(!pfn)) + return; + + fn = symbol_get(kvm_get_kvm_safe); + if (WARN_ON(!fn)) { + symbol_put(kvm_put_kvm); + return; + } + + ret = fn(kvm); + symbol_put(kvm_get_kvm_safe); + if (!ret) { + symbol_put(kvm_put_kvm); + return; + } + + viommu->put_kvm = pfn; + viommu->kvm = kvm; +} + +static void viommu_put_kvm(struct iommufd_viommu *viommu) +{ + if (!viommu->kvm) + return; + + if (WARN_ON(!viommu->put_kvm)) + goto clear; + + viommu->put_kvm(viommu->kvm); + viommu->put_kvm = NULL; + symbol_put(kvm_put_kvm); + +clear: + viommu->kvm = NULL; +} +#else +static void viommu_get_kvm_safe(struct iommufd_viommu *viommu, struct kvm *kvm) +{ +} + +static void 
viommu_put_kvm(struct iommufd_viommu *viommu) +{ +} +#endif + void iommufd_viommu_destroy(struct iommufd_object *obj) { struct iommufd_viommu *viommu = @@ -10,6 +70,7 @@ void iommufd_viommu_destroy(struct iommufd_object *obj) if (viommu->ops && viommu->ops->destroy) viommu->ops->destroy(viommu); + viommu_put_kvm(viommu); refcount_dec(&viommu->hwpt->common.obj.users); xa_destroy(&viommu->vdevs); } @@ -68,6 +129,7 @@ int iommufd_viommu_alloc_ioctl(struct iommufd_ucmd *ucmd) * on its own. */ viommu->iommu_dev = __iommu_get_iommu_dev(idev->dev); + viommu_get_kvm_safe(viommu, idev->kvm); cmd->out_viommu_id = viommu->obj.id; rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd)); diff --git a/include/linux/iommufd.h b/include/linux/iommufd.h index 2b2d6095309c..2712421802b9 100644 --- a/include/linux/iommufd.h +++ b/include/linux/iommufd.h @@ -104,6 +104,9 @@ struct iommufd_viommu { struct rw_semaphore veventqs_rwsem; unsigned int type; + + struct kvm *kvm; + void (*put_kvm)(struct kvm *kvm); }; /** -- 2.25.1