mshv_region_hmm_fault_and_lock() called hmm_range_fault() once for the
entire requested range. hmm_range_fault() can only handle a single VMA per
call, so a region whose user address range spans multiple VMAs fails the
fault even though each individual VMA is fault-able.

Walk the requested range VMA by VMA under mmap_read_lock and call
hmm_range_fault() for each segment where [vma->vm_start, vma->vm_end)
intersects [start, end). The mmu notifier sequence is captured once before
the loop, so a writer racing with the multi-VMA fault is still detected at
the closing mmu_interval_read_retry().

Tighten the read-only gate added in 3f8e229cb787 ("mshv: Don't request HMM
write fault for read-only regions") so HMM_PFN_REQ_WRITE is requested only
when both the region (HV_MAP_GPA_WRITABLE) and the backing VMA (VM_WRITE)
permit writes. Without the per-VMA check, a writable region whose
underlying VMA is read-only would still trigger COW on the host's read-only
pages.

While here, restructure mshv_region_hmm_fault_and_lock() to take the range
bounds and the output PFN array (start, end, pfns) directly rather than a
populated hmm_range; the struct is now constructed inside the function
since its fields are recomputed per VMA.

Signed-off-by: Stanislav Kinsburskii <[email protected]>
---
 drivers/hv/mshv_regions.c |   97 +++++++++++++++++++++++++++++----------------
 1 file changed, 62 insertions(+), 35 deletions(-)

diff --git a/drivers/hv/mshv_regions.c b/drivers/hv/mshv_regions.c
index 579a29f2924b8..807fff43deb43 100644
--- a/drivers/hv/mshv_regions.c
+++ b/drivers/hv/mshv_regions.c
@@ -447,37 +447,76 @@ int mshv_region_get(struct mshv_mem_region *region)
 }
 
 /**
- * mshv_region_hmm_fault_and_lock - Handle HMM faults and lock the memory 
region
+ * mshv_region_hmm_fault_and_lock - Fault in pages across VMAs and lock
+ *                                  the memory region
  * @region: Pointer to the memory region structure
- * @range: Pointer to the HMM range structure
+ * @start : Starting virtual address of the range to fault (inclusive)
+ * @end   : Ending virtual address of the range to fault (exclusive)
+ * @pfns  : Output array for page frame numbers with HMM flags
  *
- * This function performs the following steps:
- * 1. Reads the notifier sequence for the HMM range.
- * 2. Acquires a read lock on the memory map.
- * 3. Handles HMM faults for the specified range.
- * 4. Releases the read lock on the memory map.
- * 5. If successful, locks the memory region mutex.
- * 6. Verifies if the notifier sequence has changed during the operation.
- *    If it has, releases the mutex and returns -EBUSY to match with
- *    hmm_range_fault() return code for repeating.
+ * Iterates through VMAs covering [start, end), faulting in pages via
+ * hmm_range_fault() for each VMA segment.  Write faults are requested
+ * only when both the VMA and the hypervisor mapping permit writes, to
+ * avoid breaking copy-on-write semantics on read-only mappings.
  *
- * Return: 0 on success, a negative error code otherwise.
+ * On success, returns with region->mreg_mutex held; the caller is
+ * responsible for releasing it.  Returns -EBUSY if the mmu notifier
+ * sequence changed during the operation, signalling the caller to retry.
+ *
+ * Return: 0 on success, negative error code on failure.
  */
 static int mshv_region_hmm_fault_and_lock(struct mshv_mem_region *region,
-                                         struct hmm_range *range)
+                                         unsigned long start,
+                                         unsigned long end,
+                                         unsigned long *pfns)
 {
+       struct hmm_range range = {
+               .notifier = &region->mreg_mni,
+       };
+       struct mm_struct *mm = region->mreg_mni.mm;
        int ret;
 
-       range->notifier_seq = mmu_interval_read_begin(range->notifier);
-       mmap_read_lock(region->mreg_mni.mm);
-       ret = hmm_range_fault(range);
-       mmap_read_unlock(region->mreg_mni.mm);
+       range.notifier_seq = mmu_interval_read_begin(range.notifier);
+       mmap_read_lock(mm);
+       while (start < end) {
+               struct vm_area_struct *vma;
+
+               vma = vma_lookup(mm, start);
+               if (!vma) {
+                       ret = -EFAULT;
+                       break;
+               }
+
+               range.hmm_pfns = pfns;
+               range.start = start;
+               range.end = min(vma->vm_end, end);
+               range.default_flags = HMM_PFN_REQ_FAULT;
+               /*
+                * Only request writable pages from HMM when both the
+                * VMA and the hypervisor mapping allow writes.  Without
+                * this, hmm_range_fault() would trigger COW on read-only
+                * mappings (e.g. shared zero pages, file-backed pages),
+                * breaking copy-on-write semantics and potentially
+                * granting the guest write access to shared host pages.
+                */
+               if ((vma->vm_flags & VM_WRITE) &&
+                   (region->hv_map_flags & HV_MAP_GPA_WRITABLE))
+                       range.default_flags |= HMM_PFN_REQ_WRITE;
+
+               ret = hmm_range_fault(&range);
+               if (ret)
+                       break;
+
+               start = range.end;
+               pfns += (range.end - range.start) >> PAGE_SHIFT;
+       }
+       mmap_read_unlock(mm);
        if (ret)
                return ret;
 
        mutex_lock(&region->mreg_mutex);
 
-       if (mmu_interval_read_retry(range->notifier, range->notifier_seq)) {
+       if (mmu_interval_read_retry(range.notifier, range.notifier_seq)) {
                mutex_unlock(&region->mreg_mutex);
                cond_resched();
                return -EBUSY;
@@ -501,10 +540,7 @@ static int mshv_region_hmm_fault_and_lock(struct 
mshv_mem_region *region,
 static int mshv_region_range_fault(struct mshv_mem_region *region,
                                   u64 pfn_offset, u64 pfn_count)
 {
-       struct hmm_range range = {
-               .notifier = &region->mreg_mni,
-               .default_flags = HMM_PFN_REQ_FAULT,
-       };
+       unsigned long start, end;
        unsigned long *pfns;
        int ret;
        u64 i;
@@ -513,21 +549,12 @@ static int mshv_region_range_fault(struct mshv_mem_region 
*region,
        if (!pfns)
                return -ENOMEM;
 
-       range.hmm_pfns = pfns;
-       range.start = region->start_uaddr + pfn_offset * HV_HYP_PAGE_SIZE;
-       range.end = range.start + pfn_count * HV_HYP_PAGE_SIZE;
-
-       /*
-        * Only request writable pages from HMM when the region itself
-        * permits writes.  Without this, hmm_range_fault() would
-        * trigger COW on read-only regions, breaking copy-on-write
-        * semantics on shared host pages.
-        */
-       if (region->hv_map_flags & HV_MAP_GPA_WRITABLE)
-               range.default_flags |= HMM_PFN_REQ_WRITE;
+       start = region->start_uaddr + pfn_offset * HV_HYP_PAGE_SIZE;
+       end = start + pfn_count * HV_HYP_PAGE_SIZE;
 
        do {
-               ret = mshv_region_hmm_fault_and_lock(region, &range);
+               ret = mshv_region_hmm_fault_and_lock(region, start, end,
+                                                    pfns);
        } while (ret == -EBUSY);
 
        if (ret)



Reply via email to