... by factoring it out from track_pfn_remap().

For PMDs/PUDs, actually check the full range, and trigger a fallback
if we run into this "different memory types / cachemodes" scenario.

Add some documentation.

Will checking each page result in undesired overhead? We'll have to
find out. Not checking each page looks wrong, though. Maybe we could
optimize the lookup internally.

Signed-off-by: David Hildenbrand <da...@redhat.com>
---
 arch/x86/mm/pat/memtype.c | 24 ++++++++----------------
 include/linux/pgtable.h   | 28 ++++++++++++++++++++--------
 mm/huge_memory.c          |  7 +++++--
 mm/memory.c               |  4 ++--
 4 files changed, 35 insertions(+), 28 deletions(-)

diff --git a/arch/x86/mm/pat/memtype.c b/arch/x86/mm/pat/memtype.c
index edec5859651d6..193e33251b18f 100644
--- a/arch/x86/mm/pat/memtype.c
+++ b/arch/x86/mm/pat/memtype.c
@@ -1031,7 +1031,6 @@ int track_pfn_remap(struct vm_area_struct *vma, pgprot_t *prot,
                    unsigned long pfn, unsigned long addr, unsigned long size)
 {
        resource_size_t paddr = (resource_size_t)pfn << PAGE_SHIFT;
-       enum page_cache_mode pcm;
 
        /* reserve the whole chunk starting from paddr */
        if (!vma || (addr == vma->vm_start
@@ -1044,13 +1043,17 @@ int track_pfn_remap(struct vm_area_struct *vma, pgprot_t *prot,
                return ret;
        }
 
+       return pfnmap_sanitize_pgprot(pfn, size, prot);
+}
+
+int pfnmap_sanitize_pgprot(unsigned long pfn, unsigned long size, pgprot_t *prot)
+{
+       resource_size_t paddr = (resource_size_t)pfn << PAGE_SHIFT;
+       enum page_cache_mode pcm;
+
        if (!pat_enabled())
                return 0;
 
-       /*
-        * For anything smaller than the vma size we set prot based on the
-        * lookup.
-        */
        pcm = lookup_memtype(paddr);
 
        /* Check memtype for the remaining pages */
@@ -1065,17 +1068,6 @@ int track_pfn_remap(struct vm_area_struct *vma, pgprot_t *prot,
        return 0;
 }
 
-void track_pfn_insert(struct vm_area_struct *vma, pgprot_t *prot, pfn_t pfn)
-{
-       enum page_cache_mode pcm;
-
-       if (!pat_enabled())
-               return;
-
-       pcm = lookup_memtype(pfn_t_to_phys(pfn));
-       pgprot_set_cachemode(prot, pcm);
-}
-
 /*
  * untrack_pfn is called while unmapping a pfnmap for a region.
  * untrack can be called for a specific region indicated by pfn and size or
diff --git a/include/linux/pgtable.h b/include/linux/pgtable.h
index b50447ef1c921..91aadfe2515a5 100644
--- a/include/linux/pgtable.h
+++ b/include/linux/pgtable.h
@@ -1500,13 +1500,10 @@ static inline int track_pfn_remap(struct vm_area_struct *vma, pgprot_t *prot,
        return 0;
 }
 
-/*
- * track_pfn_insert is called when a _new_ single pfn is established
- * by vmf_insert_pfn().
- */
-static inline void track_pfn_insert(struct vm_area_struct *vma, pgprot_t *prot,
-                                   pfn_t pfn)
+static inline int pfnmap_sanitize_pgprot(unsigned long pfn, unsigned long size,
+               pgprot_t *prot)
 {
+       return 0;
 }
 
 /*
@@ -1556,8 +1553,23 @@ static inline void untrack_pfn_clear(struct vm_area_struct *vma)
 extern int track_pfn_remap(struct vm_area_struct *vma, pgprot_t *prot,
                           unsigned long pfn, unsigned long addr,
                           unsigned long size);
-extern void track_pfn_insert(struct vm_area_struct *vma, pgprot_t *prot,
-                            pfn_t pfn);
+
+/**
+ * pfnmap_sanitize_pgprot - sanitize the pgprot for a pfn range
+ * @pfn: the start of the pfn range
+ * @size: the size of the pfn range in bytes
+ * @prot: the pgprot to sanitize
+ *
+ * Sanitize the given pgprot for a pfn range, for example, adjusting the
+ * cachemode.
+ *
+ * This function cannot fail for a single page, but can fail for multiple
+ * pages.
+ *
+ * Return: 0 on success and -EINVAL on error.
+ */
+int pfnmap_sanitize_pgprot(unsigned long pfn, unsigned long size,
+               pgprot_t *prot);
 extern int track_pfn_copy(struct vm_area_struct *dst_vma,
                struct vm_area_struct *src_vma, unsigned long *pfn);
 extern void untrack_pfn_copy(struct vm_area_struct *dst_vma,
diff --git a/mm/huge_memory.c b/mm/huge_memory.c
index fdcf0a6049b9f..b8ae5e1493315 100644
--- a/mm/huge_memory.c
+++ b/mm/huge_memory.c
@@ -1455,7 +1455,9 @@ vm_fault_t vmf_insert_pfn_pmd(struct vm_fault *vmf, pfn_t pfn, bool write)
                        return VM_FAULT_OOM;
        }
 
-       track_pfn_insert(vma, &pgprot, pfn);
+       if (pfnmap_sanitize_pgprot(pfn_t_to_pfn(pfn), PAGE_SIZE, &pgprot))
+               return VM_FAULT_FALLBACK;
+
        ptl = pmd_lock(vma->vm_mm, vmf->pmd);
        error = insert_pfn_pmd(vma, addr, vmf->pmd, pfn, pgprot, write,
                        pgtable);
@@ -1577,7 +1579,8 @@ vm_fault_t vmf_insert_pfn_pud(struct vm_fault *vmf, pfn_t pfn, bool write)
        if (addr < vma->vm_start || addr >= vma->vm_end)
                return VM_FAULT_SIGBUS;
 
-       track_pfn_insert(vma, &pgprot, pfn);
+       if (pfnmap_sanitize_pgprot(pfn_t_to_pfn(pfn), PAGE_SIZE, &pgprot))
+               return VM_FAULT_FALLBACK;
 
        ptl = pud_lock(vma->vm_mm, vmf->pud);
        insert_pfn_pud(vma, addr, vmf->pud, pfn, write);
diff --git a/mm/memory.c b/mm/memory.c
index 424420349bd3c..c737a8625866a 100644
--- a/mm/memory.c
+++ b/mm/memory.c
@@ -2563,7 +2563,7 @@ vm_fault_t vmf_insert_pfn_prot(struct vm_area_struct *vma, unsigned long addr,
        if (!pfn_modify_allowed(pfn, pgprot))
                return VM_FAULT_SIGBUS;
 
-       track_pfn_insert(vma, &pgprot, __pfn_to_pfn_t(pfn, PFN_DEV));
+       pfnmap_sanitize_pgprot(pfn, PAGE_SIZE, &pgprot);
 
        return insert_pfn(vma, addr, __pfn_to_pfn_t(pfn, PFN_DEV), pgprot,
                        false);
@@ -2626,7 +2626,7 @@ static vm_fault_t __vm_insert_mixed(struct vm_area_struct *vma,
        if (addr < vma->vm_start || addr >= vma->vm_end)
                return VM_FAULT_SIGBUS;
 
-       track_pfn_insert(vma, &pgprot, pfn);
+       pfnmap_sanitize_pgprot(pfn_t_to_pfn(pfn), PAGE_SIZE, &pgprot);
 
        if (!pfn_modify_allowed(pfn_t_to_pfn(pfn), pgprot))
                return VM_FAULT_SIGBUS;
-- 
2.49.0

Reply via email to