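Add a Rust wrapper for the C `pci_sriov_get_totalvfs()` function, allowing
drivers to query the total number of SR-IOV virtual functions a PCI device
supports. A C helper is added for !CONFIG_PCI_IOV builds, where the function
is a static inline stub that bindgen cannot bind directly.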

Cc: Dirk Behme <[email protected]>
Cc: Alexandre Courbot <[email protected]>
Signed-off-by: Zhi Wang <[email protected]>
---
 rust/helpers/pci.c |  7 +++++++
 rust/kernel/pci.rs | 14 ++++++++++++++
 2 files changed, 21 insertions(+)

diff --git a/rust/helpers/pci.c b/rust/helpers/pci.c
index 5043c9909d44..a3072fbe2871 100644
--- a/rust/helpers/pci.c
+++ b/rust/helpers/pci.c
@@ -29,6 +29,13 @@ __rust_helper u32 rust_helper_pci_ext_cap_next(u32 header)
        return PCI_EXT_CAP_NEXT(header);
 }
 
+#ifndef CONFIG_PCI_IOV /* with CONFIG_PCI_IOV=y, bindgen binds the exported C function */
+__rust_helper int rust_helper_pci_sriov_get_totalvfs(struct pci_dev *dev)
+{
+       return pci_sriov_get_totalvfs(dev); /* static inline stub in this configuration */
+}
+#endif /* !CONFIG_PCI_IOV */
+
 #ifndef CONFIG_PCI_MSI
 __rust_helper int rust_helper_pci_alloc_irq_vectors(struct pci_dev *dev,
                                                    unsigned int min_vecs,
diff --git a/rust/kernel/pci.rs b/rust/kernel/pci.rs
index fc9c8e2077b2..c787f62b7f53 100644
--- a/rust/kernel/pci.rs
+++ b/rust/kernel/pci.rs
@@ -450,6 +450,20 @@ pub fn pci_class(&self) -> Class {
         // SAFETY: `self.as_raw` is a valid pointer to a `struct pci_dev`.
         Class::from_raw(unsafe { (*self.as_raw()).class })
     }
+
+    /// Returns the number of SR-IOV VFs the device supports (possibly capped by its driver).
+    ///
+    /// Returns `Err(ENODEV)` if none are available, e.g. the device is not an SR-IOV physical
+    /// function or the kernel was built without `CONFIG_PCI_IOV`.
+    pub fn sriov_get_totalvfs(&self) -> Result<u16> {
+        // SAFETY: `self.as_raw()` is a valid pointer to a `struct pci_dev`.
+        let vfs = unsafe { bindings::pci_sriov_get_totalvfs(self.as_raw()) };
+        if vfs == 0 {
+            return Err(ENODEV);
+        }
+        // TotalVFs and the driver-set VF limit are both `u16`, so this conversion cannot fail.
+        Ok(u16::try_from(vfs)?)
+    }
 }
 
 impl Device<device::Core> {
-- 
2.51.0

Reply via email to