mshv_region_process_pfns() conflated three concerns: validating the first
PFN of a chunk, locating the longest contiguous run of same-stride PFNs
starting from there, and dispatching the chunk to the handler. The
locate-and-dispatch interleaving made the partial-consume case (4K-to-2M
stride transition inside a same-validity run) emergent rather than
explicit, and required process_range to handle a return value that was
simultaneously a count and an error code.

Split the locate step out into mshv_region_chunk_size().  The new helper
takes a starting offset and an upper bound, returns the length of the
same-stride run, and reports whether that run is huge-page-backed via an
out-parameter. mshv_region_process_pfns() goes away;
mshv_region_process_range() now drives the loop directly, calling
chunk_size() for the next segment length and dispatching the handler with
the precomputed huge_page hint.

mshv_chunk_stride() now takes a PFN instead of a struct page * and
validates it internally, so each call site no longer needs its own
mshv_pfn_valid() check before pfn_to_page().

No functional change; the per-handler dispatch shape, segmentation
boundaries, and lock context are all preserved.

Signed-off-by: Stanislav Kinsburskii <[email protected]>
---
 drivers/hv/mshv_regions.c |  103 ++++++++++++++++++++++-----------------------
 1 file changed, 50 insertions(+), 53 deletions(-)

diff --git a/drivers/hv/mshv_regions.c b/drivers/hv/mshv_regions.c
index 77fc94733cb20..090c4052f0f4d 100644
--- a/drivers/hv/mshv_regions.c
+++ b/drivers/hv/mshv_regions.c
@@ -41,7 +41,7 @@ void mshv_region_init_pfns(struct mshv_mem_region *region)
 
 /**
  * mshv_chunk_stride - Compute stride for mapping guest memory
- * @page     : The page to check for huge page backing
+ * @pfn      : The PFN to check for huge page backing
  * @gfn      : Guest frame number for the mapping
  * @pfn_count: Total number of pages in the mapping
  *
@@ -51,11 +51,16 @@ void mshv_region_init_pfns(struct mshv_mem_region *region)
  *
  * Return: Stride in pages, or -EINVAL if page order is unsupported.
  */
-static int mshv_chunk_stride(struct page *page,
-                            u64 gfn, u64 pfn_count)
+static int mshv_chunk_stride(unsigned long pfn, u64 gfn, u64 pfn_count)
 {
+       struct page *page;
        unsigned int page_order;
 
+       if (!mshv_pfn_valid(pfn))
+               return -EINVAL;
+
+       page = pfn_to_page(pfn);
+
        /*
         * Use single page stride by default. For huge page stride, the
         * page must be compound and point to the head of the compound
@@ -76,65 +81,51 @@ static int mshv_chunk_stride(struct page *page,
 }
 
 /**
- * mshv_region_process_chunk - Processes a contiguous chunk of memory pages
- *                             in a region.
- * @region    : Pointer to the memory region structure.
- * @flags     : Flags to pass to the handler.
- * @pfn_offset: Offset into the region's PFNs array to start processing.
- * @pfn_count : Number of PFNs to process.
- * @handler   : Callback function to handle the chunk.
+ * mshv_region_chunk_size - Length of the next same-stride PFN run.
+ * @region    : Memory region whose PFN array is being walked.
+ * @pfn_offset: Offset into region->mreg_pfns at which to start; the
+ *              PFN at this offset must be valid.
+ * @pfn_count : Upper bound on the run length (not necessarily the
+ *              region's total length; typically the residual passed
+ *              from mshv_region_process_range()).
+ * @huge_page : Out-parameter set to true if the run is backed by
+ *              PMD-order folios and may be dispatched as 2 MiB
+ *              chunks; false for 4 KiB-stride dispatch.
  *
- * This function scans the region's PFNs starting from @pfn_offset,
- * checking for contiguous valid PFNs backed by pages of the same size
- * (normal or huge). It invokes @handler for the chunk of contiguous valid
- * PFNs found. Returns the number of PFNs handled, or a negative error code
- * if the first PFN is invalid or the handler fails.
+ * Walks the PFN array starting at @pfn_offset and returns the length
+ * of the longest contiguous run that shares the stride classification
+ * (4 KiB vs 2 MiB) of the first PFN.  An invalid PFN inside the run
+ * terminates it.  The run is bounded above by @pfn_count.
  *
- * Note: The @handler callback must be able to handle valid PFNs backed by
- * both normal and huge pages.
+ * The caller may then dispatch [pfn_offset, pfn_offset + return) to a
+ * handler with @huge_page indicating which stride applies.  After the
+ * dispatch the caller advances by the returned length and re-invokes
+ * this function for the next run.
  *
- * Return: Number of pages handled, or negative error code.
+ * Return: Length of the run in PFNs, or a negative errno from
+ *         mshv_chunk_stride() if the starting PFN is invalid or its
+ *         backing folio order is unsupported.
  */
-static long mshv_region_process_pfns(struct mshv_mem_region *region,
-                                    u32 flags,
-                                    u64 pfn_offset, u64 pfn_count,
-                                    int (*handler)(struct mshv_mem_region *region,
-                                                   u32 flags,
-                                                   u64 pfn_offset,
-                                                   u64 pfn_count,
-                                                   bool huge_page))
+static long mshv_region_chunk_size(struct mshv_mem_region *region,
+                                  u64 pfn_offset, u64 pfn_count,
+                                  bool *huge_page)
 {
+       unsigned long *pfns = region->mreg_pfns + pfn_offset;
        u64 gfn = region->start_gfn + pfn_offset;
-       u64 count;
-       unsigned long pfn;
-       int stride, ret;
+       long count = 0, stride; /* signed: mshv_chunk_stride() may return -errno */
 
-       pfn = region->mreg_pfns[pfn_offset];
-       if (!mshv_pfn_valid(pfn))
-               return -EINVAL;
-
-       stride = mshv_chunk_stride(pfn_to_page(pfn), gfn, pfn_count);
+       stride = mshv_chunk_stride(pfns[0], gfn, pfn_count);
        if (stride < 0)
                return stride;
 
-       /* Start at stride since the first stride is validated */
-       for (count = stride; count < pfn_count ; count += stride) {
-               pfn = region->mreg_pfns[pfn_offset + count];
-
-               /* Break if current pfn is invalid */
-               if (!mshv_pfn_valid(pfn))
-                       break;
-
-               /* Break if stride size changes */
-               if (stride != mshv_chunk_stride(pfn_to_page(pfn),
+       for (count = stride; count < pfn_count; count += stride) {
+               if (stride != mshv_chunk_stride(pfns[count],
                                                gfn + count,
                                                pfn_count - count))
                        break;
        }
 
-       ret = handler(region, flags, pfn_offset, count, stride > 1);
-       if (ret)
-               return ret;
+       *huge_page = stride > 1;
 
        return count;
 }
@@ -150,7 +141,7 @@ static long mshv_region_process_pfns(struct mshv_mem_region *region,
  *
  * Iterates over the specified range of PFNs in @region, skipping
  * invalid PFNs. For each contiguous chunk of valid PFNS, invokes
- * @handler via mshv_region_process_pfns.
+ * @handler.
  *
  * Note: The @handler callback must be able to handle PFNs backed by both
  * normal and huge pages.
@@ -176,6 +167,9 @@ static int mshv_region_process_range(struct mshv_mem_region *region,
                return -EINVAL;
 
        while (pfn_count) {
+               bool huge_page;
+               long count;
+
                /* Skip non-present pages */
                if (!mshv_pfn_valid(region->mreg_pfns[pfn_offset])) {
                        pfn_offset++;
@@ -183,14 +177,17 @@ static int mshv_region_process_range(struct mshv_mem_region *region,
                        continue;
                }
 
-               ret = mshv_region_process_pfns(region, flags,
-                                              pfn_offset, pfn_count,
-                                              handler);
+               count = mshv_region_chunk_size(region, pfn_offset, pfn_count,
+                                              &huge_page);
+               if (count < 0)
+                       return count;
+
+               ret = handler(region, flags, pfn_offset, count, huge_page);
                if (ret < 0)
                        return ret;
 
-               pfn_offset += ret;
-               pfn_count -= ret;
+               pfn_offset += count;
+               pfn_count -= count;
        }
 
        return 0;



Reply via email to