nmi_uaccess_okay() emits a warning if the current CR3 != mm->pgd.
With ASI, CR3 may legitimately point to the current ASI domain's PGD
instead, so accept that PGD as well and only warn if CR3 matches neither.

Co-developed-by: Junaid Shahid <juna...@google.com>
Signed-off-by: Junaid Shahid <juna...@google.com>
Co-developed-by: Yosry Ahmed <yosryah...@google.com>
Signed-off-by: Yosry Ahmed <yosryah...@google.com>
Signed-off-by: Brendan Jackman <jackm...@google.com>
---
 arch/x86/mm/tlb.c | 26 +++++++++++++++++++++-----
 1 file changed, 21 insertions(+), 5 deletions(-)

diff --git a/arch/x86/mm/tlb.c b/arch/x86/mm/tlb.c
index b2a13fdab0c6454c1d9d4e3338801f3402da4191..c41e083c5b5281684be79ad0391c1a5fc7b0c493 100644
--- a/arch/x86/mm/tlb.c
+++ b/arch/x86/mm/tlb.c
@@ -1340,6 +1340,22 @@ void arch_tlbbatch_flush(struct arch_tlbflush_unmap_batch *batch)
        put_cpu();
 }
 
+static inline bool cr3_matches_current_mm(void)
+{
+       struct asi *asi = asi_get_current();
+       pgd_t *pgd_asi = asi_pgd(asi);
+       pgd_t *pgd_cr3;
+
+       /*
+        * Look up the ASI PGD before reading CR3 (barrier() stops the
+        * compiler reordering them): otherwise an NMI doing asi_exit() in
+        * between could leave CR3 on the ASI PGD with no current ASI domain.
+        */
+       barrier();
+       pgd_cr3 = __va(read_cr3_pa());
+       return pgd_cr3 == current->mm->pgd || pgd_cr3 == pgd_asi;
+}
+
 /*
  * Blindly accessing user memory from NMI context can be dangerous
  * if we're in the middle of switching the current user task or
@@ -1355,10 +1371,10 @@ bool nmi_uaccess_okay(void)
        VM_WARN_ON_ONCE(!loaded_mm);
 
        /*
-        * The condition we want to check is
-        * current_mm->pgd == __va(read_cr3_pa()).  This may be slow, though,
-        * if we're running in a VM with shadow paging, and nmi_uaccess_okay()
-        * is supposed to be reasonably fast.
+        * The condition we want to check is that CR3 points to either
+        * current_mm->pgd or the current ASI domain's PGD. Reading CR3 may
+        * be slow, though, if we're running in a VM with shadow paging, and
+        * nmi_uaccess_okay() is supposed to be reasonably fast.
         *
         * Instead, we check the almost equivalent but somewhat conservative
         * condition below, and we rely on the fact that switch_mm_irqs_off()
@@ -1367,7 +1383,7 @@ bool nmi_uaccess_okay(void)
        if (loaded_mm != current_mm)
                return false;
 
-       VM_WARN_ON_ONCE(current_mm->pgd != __va(read_cr3_pa()));
+       VM_WARN_ON_ONCE(!cr3_matches_current_mm());
 
        return true;
 }

-- 
2.47.1.613.gc27f4b7a9f-goog


Reply via email to