The per-region pfns array is only used by pinned regions, where it records
the host PFNs returned by pin_user_pages() so they can later be unpinned,
shared, and unshared. Movable regions get their PFNs from HMM on every
fault, and MMIO regions need only a single base PFN. Keeping a 2 MiB-per-1
GiB flexible array around for region types that never read it is pure
overhead.

Convert mreg_pfns to a pointer and allocate it (via vmalloc_array, since it
can be large) only for MSHV_REGION_TYPE_MEM_PINNED. Place it in a union
with mreg_mmio_pfn so MMIO regions reuse the storage for the device base
PFN. Movable regions leave the pointer NULL.

With the flexible array gone, the struct itself becomes fixed-size and the
main allocation can move from vzalloc() to kzalloc().

Gate every mreg_pfns access on mreg_type == MSHV_REGION_TYPE_MEM_PINNED:
share/unshare/invalidate_pfns short-circuit for other types, and the
destructor frees the array only for pinned regions. A NULL check would not
be safe here because the union also stores MMIO regions' mmio_pfn, which is
typically non-zero.

The movable-region fault path no longer copies HMM-collected PFNs into
mreg_pfns; instead it post-processes the temporary HMM array in place
(stamping skipped slots with MSHV_INVALID_PFN) and hands it directly to the
remap helper. Movable regions are now stateless from the kernel's point of
view; the hypervisor's SLAT is the source of truth.

Signed-off-by: Stanislav Kinsburskii <[email protected]>
---
 drivers/hv/mshv_regions.c |   77 +++++++++++++++++++++++----------------------
 drivers/hv/mshv_root.h    |    7 ++--
 2 files changed, 43 insertions(+), 41 deletions(-)

diff --git a/drivers/hv/mshv_regions.c b/drivers/hv/mshv_regions.c
index e20db61e9829f..a4bfec9279ede 100644
--- a/drivers/hv/mshv_regions.c
+++ b/drivers/hv/mshv_regions.c
@@ -42,8 +42,8 @@ static inline bool mshv_pfn_valid(unsigned long pfn)
        return pfn != MSHV_INVALID_PFN;
 }
 
-static void mshv_region_init_pfns_range(struct mshv_region *region,
-                                       u64 pfn_offset, u64 pfn_count)
+static void mshv_region_init_pfns(struct mshv_region *region,
+                                 u64 pfn_offset, u64 pfn_count)
 {
        u64 i;
 
@@ -51,11 +51,6 @@ static void mshv_region_init_pfns_range(struct mshv_region *region,
                region->mreg_pfns[i] = MSHV_INVALID_PFN;
 }
 
-void mshv_region_init_pfns(struct mshv_region *region)
-{
-       mshv_region_init_pfns_range(region, 0, region->nr_pfns);
-}
-
 /**
  * mshv_chunk_stride - Compute stride for mapping guest memory
  * @pfn      : The PFN to check for huge page backing
@@ -220,7 +215,7 @@ struct mshv_region *mshv_region_create(struct mshv_partition *partition,
        struct mshv_region *region;
        int ret = 0;
 
-       region = vzalloc(struct_size(region, mreg_pfns, nr_pfns));
+       region = kzalloc_obj(struct mshv_region);
        if (!region)
                return ERR_PTR(-ENOMEM);
 
@@ -235,8 +230,6 @@ struct mshv_region *mshv_region_create(struct mshv_partition *partition,
        if (flags & BIT(MSHV_SET_MEM_BIT_EXECUTABLE))
                region->hv_map_flags |= HV_MAP_GPA_EXECUTABLE;
 
-       mshv_region_init_pfns(region);
-
        mutex_init(&region->mreg_mutex);
        kref_init(&region->mreg_refcount);
 
@@ -248,6 +241,12 @@ struct mshv_region *mshv_region_create(struct mshv_partition *partition,
                                                   &mshv_region_mni_ops);
                break;
        case MSHV_REGION_TYPE_MEM_PINNED:
+               region->mreg_pfns = vmalloc_array(nr_pfns, sizeof(unsigned long));
+               if (!region->mreg_pfns) {
+                       ret = -ENOMEM;
+                       break;
+               }
+               mshv_region_init_pfns(region, 0, region->nr_pfns);
                break;
        case MSHV_REGION_TYPE_MMIO:
                region->mreg_mmio_pfn = mmio_pfn;
@@ -262,7 +261,7 @@ struct mshv_region *mshv_region_create(struct mshv_partition *partition,
        return region;
 
 free_region:
-       vfree(region);
+       kfree(region);
        return ERR_PTR(ret);
 }
 
@@ -289,6 +288,9 @@ static int mshv_region_share(struct mshv_region *region)
 {
        u32 flags = HV_MODIFY_SPA_PAGE_HOST_ACCESS_MAKE_SHARED;
 
+       if (region->mreg_type != MSHV_REGION_TYPE_MEM_PINNED)
+               return -EINVAL;
+
        return mshv_region_process_range(region, flags,
                                         0, region->nr_pfns,
                                         region->mreg_pfns,
@@ -317,6 +319,9 @@ static int mshv_region_unshare(struct mshv_region *region)
 {
        u32 flags = HV_MODIFY_SPA_PAGE_HOST_ACCESS_MAKE_EXCLUSIVE;
 
+       if (region->mreg_type != MSHV_REGION_TYPE_MEM_PINNED)
+               return -EINVAL;
+
        return mshv_region_process_range(region, flags,
                                         0, region->nr_pfns,
                                         region->mreg_pfns,
@@ -357,29 +362,20 @@ static int mshv_region_remap_pfns(struct mshv_region *region,
                                         mshv_region_chunk_remap);
 }
 
-static int mshv_region_map(struct mshv_region *region)
-{
-       u32 map_flags = region->hv_map_flags;
-
-       return mshv_region_remap_pfns(region, map_flags,
-                                     0, region->nr_pfns,
-                                     region->mreg_pfns);
-}
-
 static void mshv_region_invalidate_pfns(struct mshv_region *region,
                                        u64 pfn_offset, u64 pfn_count)
 {
        u64 i;
 
-       for (i = pfn_offset; i < pfn_offset + pfn_count; i++) {
-               if (!mshv_pfn_valid(region->mreg_pfns[i]))
-                       continue;
+       if (region->mreg_type != MSHV_REGION_TYPE_MEM_PINNED)
+               return;
 
-               if (region->mreg_type == MSHV_REGION_TYPE_MEM_PINNED)
+       for (i = pfn_offset; i < pfn_offset + pfn_count; i++) {
+               if (mshv_pfn_valid(region->mreg_pfns[i]))
                        unpin_user_page(pfn_to_page(region->mreg_pfns[i]));
        }
 
-       mshv_region_init_pfns_range(region, pfn_offset, pfn_count);
+       mshv_region_init_pfns(region, pfn_offset, pfn_count);
 }
 
 static void mshv_region_invalidate(struct mshv_region *region)
@@ -516,7 +512,9 @@ static void mshv_region_destroy(struct kref *ref)
 
        mshv_region_invalidate(region);
 
-       vfree(region);
+       if (region->mreg_type == MSHV_REGION_TYPE_MEM_PINNED)
+               vfree(region->mreg_pfns);
+       kfree(region);
 }
 
 void mshv_region_put(struct mshv_region *region)
@@ -634,10 +632,9 @@ static int mshv_region_hmm_fault_and_lock(struct mshv_region *region,
  *   leaving missing pages as invalid PFN markers.
  *   Used for initial region setup.
  *
- * Collected PFNs are stored in region->mreg_pfns[] with HMM bookkeeping
- * flags cleared, then the range is mapped into the hypervisor. Present
- * PFNs get mapped with region access permissions; missing PFNs (invalid
- * entries) get mapped with no-access permissions.
+ * HMM bookkeeping flags are stripped from collected PFNs before mapping.
+ * Present PFNs get mapped with region access permissions; missing PFNs
+ * (marked as MSHV_INVALID_PFN) get mapped with no-access permissions.
  *
  * Return: 0 on success, negative errno on failure.
  */
@@ -666,20 +663,24 @@ static int mshv_region_collect_and_map(struct mshv_region *region,
                goto out;
 
        for (i = 0; i < pfn_count; i++) {
-               if (!(pfns[i] & HMM_PFN_VALID))
+               if (!(pfns[i] & HMM_PFN_VALID)) {
+                       pfns[i] = MSHV_INVALID_PFN;
                        continue;
+               }
                /* Skip read-only pages to avoid bypassing COW */
                if (!do_fault &&
                    (region->hv_map_flags & HV_MAP_GPA_WRITABLE) &&
-                   !(pfns[i] & HMM_PFN_WRITE))
+                   !(pfns[i] & HMM_PFN_WRITE)) {
+                       pfns[i] = MSHV_INVALID_PFN;
                        continue;
+               }
                /* Drop HMM_PFN_* flags to ensure PFNs are valid. */
-               region->mreg_pfns[pfn_offset + i] = pfns[i] & ~HMM_PFN_FLAGS;
+               pfns[i] &= ~HMM_PFN_FLAGS;
        }
 
        ret = mshv_region_remap_pfns(region, region->hv_map_flags,
                                     pfn_offset, pfn_count,
-                                    region->mreg_pfns + pfn_offset);
+                                    pfns);
 
        mutex_unlock(&region->mreg_mutex);
 out:
@@ -781,8 +782,6 @@ static bool mshv_region_interval_invalidate(struct mmu_interval_notifier *mni,
        if (ret)
                goto out_unlock;
 
-       mshv_region_invalidate_pfns(region, pfn_offset, pfn_count);
-
        mutex_unlock(&region->mreg_mutex);
 
        return true;
@@ -845,7 +844,9 @@ static int mshv_map_pinned_region(struct mshv_region *region)
                }
        }
 
-       ret = mshv_region_map(region);
+       ret = mshv_region_remap_pfns(region, region->hv_map_flags,
+                                    0, region->nr_pfns,
+                                    region->mreg_pfns);
        if (ret)
                goto share_region;
 
@@ -869,7 +870,7 @@ static int mshv_map_pinned_region(struct mshv_region *region)
                 * is intentional; unpinning host-inaccessible pages would be
                 * unsafe).
                 */
-               mshv_region_init_pfns(region);
+               mshv_region_init_pfns(region, 0, region->nr_pfns);
                goto err_out;
        }
 err_out:
diff --git a/drivers/hv/mshv_root.h b/drivers/hv/mshv_root.h
index e9bd18013b486..d79dfaac88af9 100644
--- a/drivers/hv/mshv_root.h
+++ b/drivers/hv/mshv_root.h
@@ -93,8 +93,10 @@ struct mshv_region {
        enum mshv_region_type mreg_type;
        struct mmu_interval_notifier mreg_mni;
        struct mutex mreg_mutex;        /* protects region PFNs remapping */
-       u64 mreg_mmio_pfn;
-       unsigned long mreg_pfns[];
+       union {
+               unsigned long *mreg_pfns;
+               u64 mreg_mmio_pfn;
+       };
 };
 
 struct mshv_irq_ack_notifier {
@@ -375,7 +377,6 @@ struct mshv_region *mshv_region_create(struct mshv_partition *partition,
                                       u64 guest_pfn, u64 nr_pfns,
                                       u64 uaddr, u32 flags,
                                       unsigned long mmio_pfn);
-void mshv_region_init_pfns(struct mshv_region *region);
 void mshv_region_put(struct mshv_region *region);
 int mshv_region_get(struct mshv_region *region);
 bool mshv_region_handle_gfn_fault(struct mshv_region *region, u64 gfn);



Reply via email to