The per-region pfns array is only used by pinned regions, where it records the host PFNs returned by pin_user_pages() so they can later be unpinned, shared, and unshared. Movable regions get their PFNs from HMM on every fault, and MMIO regions need only a single base PFN. Keeping a 2 MiB-per-1 GiB flexible array around for region types that never read it is pure overhead.
Convert mreg_pfns to a pointer and allocate it (via vmalloc_array(), since it can be large) only for MSHV_REGION_TYPE_MEM_PINNED regions. Place it in a union with mreg_mmio_pfn so MMIO regions reuse the storage for the device base PFN. Movable regions leave the pointer NULL. With the flexible array gone, the struct itself becomes fixed-size and the main allocation can move from vzalloc() to kzalloc() (via the kzalloc_obj() helper). Gate every mreg_pfns access on mreg_type == MSHV_REGION_TYPE_MEM_PINNED: share/unshare/invalidate_pfns short-circuit for other region types, and the destructor frees the array only for pinned regions. A NULL check would not be safe here because the union also stores MMIO regions' mmio_pfn, which is typically non-zero. The movable-region fault path no longer copies HMM-collected PFNs into mreg_pfns; instead it post-processes the temporary HMM array in place (stamping skipped slots with MSHV_INVALID_PFN) and hands it directly to the remap helper. Movable regions are now stateless from the kernel's point of view; the hypervisor's SLAT is the source of truth.
Signed-off-by: Stanislav Kinsburskii <[email protected]>
---
 drivers/hv/mshv_regions.c | 77 +++++++++++++++++++++++----------------------
 drivers/hv/mshv_root.h    |  7 ++--
 2 files changed, 43 insertions(+), 41 deletions(-)

diff --git a/drivers/hv/mshv_regions.c b/drivers/hv/mshv_regions.c
index e20db61e9829f..a4bfec9279ede 100644
--- a/drivers/hv/mshv_regions.c
+++ b/drivers/hv/mshv_regions.c
@@ -42,8 +42,8 @@ static inline bool mshv_pfn_valid(unsigned long pfn)
 	return pfn != MSHV_INVALID_PFN;
 }
 
-static void mshv_region_init_pfns_range(struct mshv_region *region,
-					u64 pfn_offset, u64 pfn_count)
+static void mshv_region_init_pfns(struct mshv_region *region,
+				  u64 pfn_offset, u64 pfn_count)
 {
 	u64 i;
 
@@ -51,11 +51,6 @@ static void mshv_region_init_pfns_range(struct mshv_region *region,
 		region->mreg_pfns[i] = MSHV_INVALID_PFN;
 }
 
-void mshv_region_init_pfns(struct mshv_region *region)
-{
-	mshv_region_init_pfns_range(region, 0, region->nr_pfns);
-}
-
 /**
  * mshv_chunk_stride - Compute stride for mapping guest memory
  * @pfn : The PFN to check for huge page backing
@@ -220,7 +215,7 @@ struct mshv_region *mshv_region_create(struct mshv_partition *partition,
 	struct mshv_region *region;
 	int ret = 0;
 
-	region = vzalloc(struct_size(region, mreg_pfns, nr_pfns));
+	region = kzalloc_obj(struct mshv_region);
 	if (!region)
 		return ERR_PTR(-ENOMEM);
 
@@ -235,8 +230,6 @@ struct mshv_region *mshv_region_create(struct mshv_partition *partition,
 	if (flags & BIT(MSHV_SET_MEM_BIT_EXECUTABLE))
 		region->hv_map_flags |= HV_MAP_GPA_EXECUTABLE;
 
-	mshv_region_init_pfns(region);
-
 	mutex_init(&region->mreg_mutex);
 	kref_init(&region->mreg_refcount);
 
@@ -248,6 +241,12 @@ struct mshv_region *mshv_region_create(struct mshv_partition *partition,
 				       &mshv_region_mni_ops);
 		break;
 	case MSHV_REGION_TYPE_MEM_PINNED:
+		region->mreg_pfns = vmalloc_array(nr_pfns, sizeof(unsigned long));
+		if (!region->mreg_pfns) {
+			ret = -ENOMEM;
+			break;
+		}
+		mshv_region_init_pfns(region, 0, region->nr_pfns);
 		break;
 	case MSHV_REGION_TYPE_MMIO:
 		region->mreg_mmio_pfn = mmio_pfn;
@@ -262,7 +261,7 @@ struct mshv_region *mshv_region_create(struct mshv_partition *partition,
 	return region;
 
 free_region:
-	vfree(region);
+	kfree(region);
 	return ERR_PTR(ret);
 }
 
@@ -289,6 +288,9 @@ static int mshv_region_share(struct mshv_region *region)
 {
 	u32 flags = HV_MODIFY_SPA_PAGE_HOST_ACCESS_MAKE_SHARED;
 
+	if (region->mreg_type != MSHV_REGION_TYPE_MEM_PINNED)
+		return -EINVAL;
+
 	return mshv_region_process_range(region, flags,
 					 0, region->nr_pfns,
 					 region->mreg_pfns,
@@ -317,6 +319,9 @@ static int mshv_region_unshare(struct mshv_region *region)
 {
 	u32 flags = HV_MODIFY_SPA_PAGE_HOST_ACCESS_MAKE_EXCLUSIVE;
 
+	if (region->mreg_type != MSHV_REGION_TYPE_MEM_PINNED)
+		return -EINVAL;
+
 	return mshv_region_process_range(region, flags,
 					 0, region->nr_pfns,
 					 region->mreg_pfns,
@@ -357,29 +362,20 @@ static int mshv_region_remap_pfns(struct mshv_region *region,
 				      mshv_region_chunk_remap);
 }
 
-static int mshv_region_map(struct mshv_region *region)
-{
-	u32 map_flags = region->hv_map_flags;
-
-	return mshv_region_remap_pfns(region, map_flags,
-				      0, region->nr_pfns,
-				      region->mreg_pfns);
-}
-
 static void mshv_region_invalidate_pfns(struct mshv_region *region,
 					u64 pfn_offset, u64 pfn_count)
 {
 	u64 i;
 
-	for (i = pfn_offset; i < pfn_offset + pfn_count; i++) {
-		if (!mshv_pfn_valid(region->mreg_pfns[i]))
-			continue;
+	if (region->mreg_type != MSHV_REGION_TYPE_MEM_PINNED)
+		return;
 
-		if (region->mreg_type == MSHV_REGION_TYPE_MEM_PINNED)
+	for (i = pfn_offset; i < pfn_offset + pfn_count; i++) {
+		if (mshv_pfn_valid(region->mreg_pfns[i]))
 			unpin_user_page(pfn_to_page(region->mreg_pfns[i]));
 	}
 
-	mshv_region_init_pfns_range(region, pfn_offset, pfn_count);
+	mshv_region_init_pfns(region, pfn_offset, pfn_count);
 }
 
 static void mshv_region_invalidate(struct mshv_region *region)
@@ -516,7 +512,9 @@ static void mshv_region_destroy(struct kref *ref)
 
 	mshv_region_invalidate(region);
 
-	vfree(region);
+	if (region->mreg_type == MSHV_REGION_TYPE_MEM_PINNED)
+		vfree(region->mreg_pfns);
+	kfree(region);
 }
 
 void mshv_region_put(struct mshv_region *region)
@@ -634,10 +632,9 @@ static int mshv_region_hmm_fault_and_lock(struct mshv_region *region,
  * leaving missing pages as invalid PFN markers.
  * Used for initial region setup.
  *
- * Collected PFNs are stored in region->mreg_pfns[] with HMM bookkeeping
- * flags cleared, then the range is mapped into the hypervisor. Present
- * PFNs get mapped with region access permissions; missing PFNs (invalid
- * entries) get mapped with no-access permissions.
+ * HMM bookkeeping flags are stripped from collected PFNs before mapping.
+ * Present PFNs get mapped with region access permissions; missing PFNs
+ * (marked as MSHV_INVALID_PFN) get mapped with no-access permissions.
 *
 * Return: 0 on success, negative errno on failure.
 */
@@ -666,20 +663,24 @@
 		goto out;
 
 	for (i = 0; i < pfn_count; i++) {
-		if (!(pfns[i] & HMM_PFN_VALID))
+		if (!(pfns[i] & HMM_PFN_VALID)) {
+			pfns[i] = MSHV_INVALID_PFN;
 			continue;
+		}
 
 		/* Skip read-only pages to avoid bypassing COW */
 		if (!do_fault &&
 		    (region->hv_map_flags & HV_MAP_GPA_WRITABLE) &&
-		    !(pfns[i] & HMM_PFN_WRITE))
+		    !(pfns[i] & HMM_PFN_WRITE)) {
+			pfns[i] = MSHV_INVALID_PFN;
 			continue;
+		}
 
 		/* Drop HMM_PFN_* flags to ensure PFNs are valid. */
-		region->mreg_pfns[pfn_offset + i] = pfns[i] & ~HMM_PFN_FLAGS;
+		pfns[i] &= ~HMM_PFN_FLAGS;
 	}
 
 	ret = mshv_region_remap_pfns(region, region->hv_map_flags,
 				     pfn_offset, pfn_count,
-				     region->mreg_pfns + pfn_offset);
+				     pfns);
 	mutex_unlock(&region->mreg_mutex);
 out:
@@ -781,8 +782,6 @@ static bool mshv_region_interval_invalidate(struct mmu_interval_notifier *mni,
 	if (ret)
 		goto out_unlock;
 
-	mshv_region_invalidate_pfns(region, pfn_offset, pfn_count);
-
 	mutex_unlock(&region->mreg_mutex);
 
 	return true;
@@ -845,7 +844,9 @@ static int mshv_map_pinned_region(struct mshv_region *region)
 		}
 	}
 
-	ret = mshv_region_map(region);
+	ret = mshv_region_remap_pfns(region, region->hv_map_flags,
+				     0, region->nr_pfns,
+				     region->mreg_pfns);
 	if (ret)
 		goto share_region;
 
@@ -869,7 +870,7 @@ static int mshv_map_pinned_region(struct mshv_region *region)
 		 * is intentional; unpinning host-inaccessible pages would be
 		 * unsafe).
 		 */
-		mshv_region_init_pfns(region);
+		mshv_region_init_pfns(region, 0, region->nr_pfns);
 		goto err_out;
 	}
 err_out:
diff --git a/drivers/hv/mshv_root.h b/drivers/hv/mshv_root.h
index e9bd18013b486..d79dfaac88af9 100644
--- a/drivers/hv/mshv_root.h
+++ b/drivers/hv/mshv_root.h
@@ -93,8 +93,10 @@ struct mshv_region {
 	enum mshv_region_type mreg_type;
 	struct mmu_interval_notifier mreg_mni;
 	struct mutex mreg_mutex; /* protects region PFNs remapping */
-	u64 mreg_mmio_pfn;
-	unsigned long mreg_pfns[];
+	union {
+		unsigned long *mreg_pfns;
+		u64 mreg_mmio_pfn;
+	};
 };
 
 struct mshv_irq_ack_notifier {
@@ -375,7 +377,6 @@ struct mshv_region *mshv_region_create(struct mshv_partition *partition,
 				       u64 guest_pfn, u64 nr_pfns,
 				       u64 uaddr, u32 flags,
 				       unsigned long mmio_pfn);
-void mshv_region_init_pfns(struct mshv_region *region);
 void mshv_region_put(struct mshv_region *region);
 int mshv_region_get(struct mshv_region *region);
 bool mshv_region_handle_gfn_fault(struct mshv_region *region, u64 gfn);

