Update the vma_modify_flags() and vma_modify_flags_uffd() functions to
accept a vma_flags_t parameter rather than a vm_flags_t one, and propagate
the changes as needed to implement this change.

Finally, update the VMA tests to reflect this.

Signed-off-by: Lorenzo Stoakes (Oracle) <[email protected]>
---
 include/linux/userfaultfd_k.h   |  3 +++
 mm/madvise.c                    | 10 ++++----
 mm/mlock.c                      | 42 ++++++++++++++++++++-------------
 mm/mprotect.c                   |  7 +++---
 mm/mseal.c                      | 10 ++++----
 mm/userfaultfd.c                | 21 +++++++++++------
 mm/vma.c                        | 15 ++++++------
 mm/vma.h                        | 15 ++++++------
 tools/testing/vma/tests/merge.c |  3 +--
 9 files changed, 73 insertions(+), 53 deletions(-)

diff --git a/include/linux/userfaultfd_k.h b/include/linux/userfaultfd_k.h
index bf4e595ac914..3bd2003328dc 100644
--- a/include/linux/userfaultfd_k.h
+++ b/include/linux/userfaultfd_k.h
@@ -23,6 +23,9 @@
 /* The set of all possible UFFD-related VM flags. */
 #define __VM_UFFD_FLAGS (VM_UFFD_MISSING | VM_UFFD_WP | VM_UFFD_MINOR)
 
+#define __VMA_UFFD_FLAGS mk_vma_flags(VMA_UFFD_MISSING_BIT, VMA_UFFD_WP_BIT, \
+                                     VMA_UFFD_MINOR_BIT)
+
 /*
  * CAREFUL: Check include/uapi/asm-generic/fcntl.h when defining
  * new flags, since they might collide with O_* ones. We want
diff --git a/mm/madvise.c b/mm/madvise.c
index afe0f01765c4..69708e953cf5 100644
--- a/mm/madvise.c
+++ b/mm/madvise.c
@@ -151,13 +151,15 @@ static int madvise_update_vma(vm_flags_t new_flags,
                struct madvise_behavior *madv_behavior)
 {
        struct vm_area_struct *vma = madv_behavior->vma;
+       vma_flags_t new_vma_flags = legacy_to_vma_flags(new_flags);
        struct madvise_behavior_range *range = &madv_behavior->range;
        struct anon_vma_name *anon_name = madv_behavior->anon_name;
        bool set_new_anon_name = madv_behavior->behavior == __MADV_SET_ANON_VMA_NAME;
        VMA_ITERATOR(vmi, madv_behavior->mm, range->start);
 
-       if (new_flags == vma->vm_flags && (!set_new_anon_name ||
-                       anon_vma_name_eq(anon_vma_name(vma), anon_name)))
+       if (vma_flags_same_mask(&vma->flags, new_vma_flags) &&
+           (!set_new_anon_name ||
+            anon_vma_name_eq(anon_vma_name(vma), anon_name)))
                return 0;
 
        if (set_new_anon_name)
@@ -165,7 +167,7 @@ static int madvise_update_vma(vm_flags_t new_flags,
                        range->start, range->end, anon_name);
        else
                vma = vma_modify_flags(&vmi, madv_behavior->prev, vma,
-                       range->start, range->end, &new_flags);
+                       range->start, range->end, &new_vma_flags);
 
        if (IS_ERR(vma))
                return PTR_ERR(vma);
@@ -174,7 +176,7 @@ static int madvise_update_vma(vm_flags_t new_flags,
 
        /* vm_flags is protected by the mmap_lock held in write mode. */
        vma_start_write(vma);
-       vm_flags_reset(vma, new_flags);
+       vma->flags = new_vma_flags;
        if (set_new_anon_name)
                return replace_anon_vma_name(vma, anon_name);
 
diff --git a/mm/mlock.c b/mm/mlock.c
index c980630afd0d..b4dbf87b0575 100644
--- a/mm/mlock.c
+++ b/mm/mlock.c
@@ -415,13 +415,14 @@ static int mlock_pte_range(pmd_t *pmd, unsigned long addr,
  * @vma - vma containing range to be mlock()ed or munlock()ed
  * @start - start address in @vma of the range
  * @end - end of range in @vma
- * @newflags - the new set of flags for @vma.
+ * @new_vma_flags - the new set of flags for @vma.
  *
  * Called for mlock(), mlock2() and mlockall(), to set @vma VM_LOCKED;
  * called for munlock() and munlockall(), to clear VM_LOCKED from @vma.
  */
 static void mlock_vma_pages_range(struct vm_area_struct *vma,
-       unsigned long start, unsigned long end, vm_flags_t newflags)
+       unsigned long start, unsigned long end,
+       vma_flags_t *new_vma_flags)
 {
        static const struct mm_walk_ops mlock_walk_ops = {
                .pmd_entry = mlock_pte_range,
@@ -439,18 +440,18 @@ static void mlock_vma_pages_range(struct vm_area_struct *vma,
         * combination should not be visible to other mmap_lock users;
         * but WRITE_ONCE so rmap walkers must see VM_IO if VM_LOCKED.
         */
-       if (newflags & VM_LOCKED)
-               newflags |= VM_IO;
+       if (vma_flags_test(new_vma_flags, VMA_LOCKED_BIT))
+               vma_flags_set(new_vma_flags, VMA_IO_BIT);
        vma_start_write(vma);
-       vm_flags_reset_once(vma, newflags);
+       WRITE_ONCE(vma->flags, *new_vma_flags);
 
        lru_add_drain();
        walk_page_range(vma->vm_mm, start, end, &mlock_walk_ops, NULL);
        lru_add_drain();
 
-       if (newflags & VM_IO) {
-               newflags &= ~VM_IO;
-               vm_flags_reset_once(vma, newflags);
+       if (vma_flags_test(new_vma_flags, VMA_IO_BIT)) {
+               vma_flags_clear(new_vma_flags, VMA_IO_BIT);
+               WRITE_ONCE(vma->flags, *new_vma_flags);
        }
 }
 
@@ -467,18 +468,24 @@ static int mlock_fixup(struct vma_iterator *vmi, struct vm_area_struct *vma,
               struct vm_area_struct **prev, unsigned long start,
               unsigned long end, vm_flags_t newflags)
 {
+       vma_flags_t new_vma_flags = legacy_to_vma_flags(newflags);
+       const vma_flags_t old_vma_flags = vma->flags;
        struct mm_struct *mm = vma->vm_mm;
        int nr_pages;
        int ret = 0;
-       vm_flags_t oldflags = vma->vm_flags;
 
-       if (newflags == oldflags || (oldflags & VM_SPECIAL) ||
+
+
+       if (vma_flags_same_pair(&old_vma_flags, &new_vma_flags) ||
+           vma_flags_test_any_mask(&old_vma_flags, VMA_SPECIAL_FLAGS) ||
            is_vm_hugetlb_page(vma) || vma == get_gate_vma(current->mm) ||
-           vma_is_dax(vma) || vma_is_secretmem(vma) || (oldflags & VM_DROPPABLE))
+           vma_is_dax(vma) || vma_is_secretmem(vma) ||
+           vma_flags_test_any_mask(&old_vma_flags, VMA_DROPPABLE)) {
                /* don't set VM_LOCKED or VM_LOCKONFAULT and don't count */
                goto out;
+       }
 
-       vma = vma_modify_flags(vmi, *prev, vma, start, end, &newflags);
+       vma = vma_modify_flags(vmi, *prev, vma, start, end, &new_vma_flags);
        if (IS_ERR(vma)) {
                ret = PTR_ERR(vma);
                goto out;
@@ -488,9 +495,9 @@ static int mlock_fixup(struct vma_iterator *vmi, struct vm_area_struct *vma,
         * Keep track of amount of locked VM.
         */
        nr_pages = (end - start) >> PAGE_SHIFT;
-       if (!(newflags & VM_LOCKED))
+       if (!vma_flags_test(&new_vma_flags, VMA_LOCKED_BIT))
                nr_pages = -nr_pages;
-       else if (oldflags & VM_LOCKED)
+       else if (vma_flags_test(&old_vma_flags, VMA_LOCKED_BIT))
                nr_pages = 0;
        mm->locked_vm += nr_pages;
 
@@ -499,12 +506,13 @@ static int mlock_fixup(struct vma_iterator *vmi, struct vm_area_struct *vma,
         * It's okay if try_to_unmap_one unmaps a page just after we
         * set VM_LOCKED, populate_vma_page_range will bring it back.
         */
-       if ((newflags & VM_LOCKED) && (oldflags & VM_LOCKED)) {
+       if (vma_flags_test(&new_vma_flags, VMA_LOCKED_BIT) &&
+           vma_flags_test(&old_vma_flags, VMA_LOCKED_BIT)) {
                /* No work to do, and mlocking twice would be wrong */
                vma_start_write(vma);
-               vm_flags_reset(vma, newflags);
+               vma->flags = new_vma_flags;
        } else {
-               mlock_vma_pages_range(vma, start, end, newflags);
+               mlock_vma_pages_range(vma, start, end, &new_vma_flags);
        }
 out:
        *prev = vma;
diff --git a/mm/mprotect.c b/mm/mprotect.c
index eaa724b99908..2b8a85689ab7 100644
--- a/mm/mprotect.c
+++ b/mm/mprotect.c
@@ -756,13 +756,11 @@ mprotect_fixup(struct vma_iterator *vmi, struct mmu_gather *tlb,
                vma_flags_clear(&new_vma_flags, VMA_ACCOUNT_BIT);
        }
 
-       newflags = vma_flags_to_legacy(new_vma_flags);
-       vma = vma_modify_flags(vmi, *pprev, vma, start, end, &newflags);
+       vma = vma_modify_flags(vmi, *pprev, vma, start, end, &new_vma_flags);
        if (IS_ERR(vma)) {
                error = PTR_ERR(vma);
                goto fail;
        }
-       new_vma_flags = legacy_to_vma_flags(newflags);
 
        *pprev = vma;
 
@@ -771,7 +769,7 @@ mprotect_fixup(struct vma_iterator *vmi, struct mmu_gather *tlb,
         * held in write mode.
         */
        vma_start_write(vma);
-       vm_flags_reset_once(vma, newflags);
+       WRITE_ONCE(vma->flags, new_vma_flags);
        if (vma_wants_manual_pte_write_upgrade(vma))
                mm_cp_flags |= MM_CP_TRY_CHANGE_WRITABLE;
        vma_set_page_prot(vma);
@@ -796,6 +794,7 @@ mprotect_fixup(struct vma_iterator *vmi, struct mmu_gather *tlb,
        }
 
        vm_stat_account(mm, vma_flags_to_legacy(old_vma_flags), -nrpages);
+       newflags = vma_flags_to_legacy(new_vma_flags);
        vm_stat_account(mm, newflags, nrpages);
        perf_event_mmap(vma);
        return 0;
diff --git a/mm/mseal.c b/mm/mseal.c
index 316b5e1dec78..fd299d60ad17 100644
--- a/mm/mseal.c
+++ b/mm/mseal.c
@@ -68,14 +68,16 @@ static int mseal_apply(struct mm_struct *mm,
        for_each_vma_range(vmi, vma, end) {
                const unsigned long curr_end = MIN(vma->vm_end, end);
 
-               if (!(vma->vm_flags & VM_SEALED)) {
-                       vm_flags_t vm_flags = vma->vm_flags | VM_SEALED;
+               if (!vma_test(vma, VMA_SEALED_BIT)) {
+                       vma_flags_t vma_flags = vma->flags;
+
+                       vma_flags_set(&vma_flags, VMA_SEALED_BIT);
 
                        vma = vma_modify_flags(&vmi, prev, vma, curr_start,
-                                              curr_end, &vm_flags);
+                                              curr_end, &vma_flags);
                        if (IS_ERR(vma))
                                return PTR_ERR(vma);
-                       vm_flags_set(vma, VM_SEALED);
+                       vma_set_flags(vma, VMA_SEALED_BIT);
                }
 
                prev = vma;
diff --git a/mm/userfaultfd.c b/mm/userfaultfd.c
index 849fb2f30233..9a93b77d3bed 100644
--- a/mm/userfaultfd.c
+++ b/mm/userfaultfd.c
@@ -2094,6 +2094,9 @@ struct vm_area_struct *userfaultfd_clear_vma(struct vma_iterator *vmi,
 {
        struct vm_area_struct *ret;
        bool give_up_on_oom = false;
+       vma_flags_t new_vma_flags = vma->flags;
+
+       vma_flags_clear_mask(&new_vma_flags, __VMA_UFFD_FLAGS);
 
        /*
         * If we are modifying only and not splitting, just give up on the merge
@@ -2107,8 +2110,8 @@ struct vm_area_struct *userfaultfd_clear_vma(struct vma_iterator *vmi,
                uffd_wp_range(vma, start, end - start, false);
 
        ret = vma_modify_flags_uffd(vmi, prev, vma, start, end,
-                                   vma->vm_flags & ~__VM_UFFD_FLAGS,
-                                   NULL_VM_UFFD_CTX, give_up_on_oom);
+                                   &new_vma_flags, NULL_VM_UFFD_CTX,
+                                   give_up_on_oom);
 
        /*
         * In the vma_merge() successful mprotect-like case 8:
@@ -2128,10 +2131,11 @@ int userfaultfd_register_range(struct userfaultfd_ctx *ctx,
                               unsigned long start, unsigned long end,
                               bool wp_async)
 {
+       vma_flags_t vma_flags = legacy_to_vma_flags(vm_flags);
        VMA_ITERATOR(vmi, ctx->mm, start);
        struct vm_area_struct *prev = vma_prev(&vmi);
        unsigned long vma_end;
-       vm_flags_t new_flags;
+       vma_flags_t new_vma_flags;
 
        if (vma->vm_start < start)
                prev = vma;
@@ -2142,23 +2146,26 @@ int userfaultfd_register_range(struct userfaultfd_ctx *ctx,
                VM_WARN_ON_ONCE(!vma_can_userfault(vma, vm_flags, wp_async));
                VM_WARN_ON_ONCE(vma->vm_userfaultfd_ctx.ctx &&
                                vma->vm_userfaultfd_ctx.ctx != ctx);
-               VM_WARN_ON_ONCE(!(vma->vm_flags & VM_MAYWRITE));
+               VM_WARN_ON_ONCE(!vma_test(vma, VMA_MAYWRITE_BIT));
 
                /*
                 * Nothing to do: this vma is already registered into this
                 * userfaultfd and with the right tracking mode too.
                 */
                if (vma->vm_userfaultfd_ctx.ctx == ctx &&
-                   (vma->vm_flags & vm_flags) == vm_flags)
+                   vma_test_all_mask(vma, vma_flags))
                        goto skip;
 
                if (vma->vm_start > start)
                        start = vma->vm_start;
                vma_end = min(end, vma->vm_end);
 
-               new_flags = (vma->vm_flags & ~__VM_UFFD_FLAGS) | vm_flags;
+               new_vma_flags = vma->flags;
+               vma_flags_clear_mask(&new_vma_flags, __VMA_UFFD_FLAGS);
+               vma_flags_set_mask(&new_vma_flags, vma_flags);
+
                vma = vma_modify_flags_uffd(&vmi, prev, vma, start, vma_end,
-                                           new_flags,
+                                           &new_vma_flags,
                                            (struct vm_userfaultfd_ctx){ctx},
                                            /* give_up_on_oom = */false);
                if (IS_ERR(vma))
diff --git a/mm/vma.c b/mm/vma.c
index 0fe4a161960e..c2c649b23465 100644
--- a/mm/vma.c
+++ b/mm/vma.c
@@ -1709,13 +1709,13 @@ static struct vm_area_struct *vma_modify(struct vma_merge_struct *vmg)
 struct vm_area_struct *vma_modify_flags(struct vma_iterator *vmi,
                struct vm_area_struct *prev, struct vm_area_struct *vma,
                unsigned long start, unsigned long end,
-               vm_flags_t *vm_flags_ptr)
+               vma_flags_t *vma_flags_ptr)
 {
        VMG_VMA_STATE(vmg, vmi, prev, vma, start, end);
-       const vm_flags_t vm_flags = *vm_flags_ptr;
+       const vma_flags_t vma_flags = *vma_flags_ptr;
        struct vm_area_struct *ret;
 
-       vmg.vm_flags = vm_flags;
+       vmg.vma_flags = vma_flags;
 
        ret = vma_modify(&vmg);
        if (IS_ERR(ret))
@@ -1727,7 +1727,7 @@ struct vm_area_struct *vma_modify_flags(struct vma_iterator *vmi,
         * them to the caller.
         */
        if (vmg.state == VMA_MERGE_SUCCESS)
-               *vm_flags_ptr = ret->vm_flags;
+               *vma_flags_ptr = ret->flags;
        return ret;
 }
 
@@ -1757,12 +1757,13 @@ struct vm_area_struct *vma_modify_policy(struct vma_iterator *vmi,
 
 struct vm_area_struct *vma_modify_flags_uffd(struct vma_iterator *vmi,
                struct vm_area_struct *prev, struct vm_area_struct *vma,
-               unsigned long start, unsigned long end, vm_flags_t vm_flags,
-               struct vm_userfaultfd_ctx new_ctx, bool give_up_on_oom)
+               unsigned long start, unsigned long end,
+               const vma_flags_t *vma_flags, struct vm_userfaultfd_ctx new_ctx,
+               bool give_up_on_oom)
 {
        VMG_VMA_STATE(vmg, vmi, prev, vma, start, end);
 
-       vmg.vm_flags = vm_flags;
+       vmg.vma_flags = *vma_flags;
        vmg.uffd_ctx = new_ctx;
        if (give_up_on_oom)
                vmg.give_up_on_oom = true;
diff --git a/mm/vma.h b/mm/vma.h
index 1f2de6cb3b97..270008e5babc 100644
--- a/mm/vma.h
+++ b/mm/vma.h
@@ -342,24 +342,23 @@ void unmap_region(struct unmap_desc *unmap);
  * @vma: The VMA containing the range @start to @end to be updated.
  * @start: The start of the range to update. May be offset within @vma.
  * @end: The exclusive end of the range to update, may be offset within @vma.
- * @vm_flags_ptr: A pointer to the VMA flags that the @start to @end range is
+ * @vma_flags_ptr: A pointer to the VMA flags that the @start to @end range is
  * about to be set to. On merge, this will be updated to include sticky flags.
  *
  * IMPORTANT: The actual modification being requested here is NOT applied,
  * rather the VMA is perhaps split, perhaps merged to accommodate the change,
  * and the caller is expected to perform the actual modification.
  *
- * In order to account for sticky VMA flags, the @vm_flags_ptr parameter points
+ * In order to account for sticky VMA flags, the @vma_flags_ptr parameter points
  * to the requested flags which are then updated so the caller, should they
  * overwrite any existing flags, correctly retains these.
  *
  * Returns: A VMA which contains the range @start to @end ready to have its
- * flags altered to *@vm_flags.
+ * flags altered to *@vma_flags.
  */
 __must_check struct vm_area_struct *vma_modify_flags(struct vma_iterator *vmi,
                struct vm_area_struct *prev, struct vm_area_struct *vma,
-               unsigned long start, unsigned long end,
-               vm_flags_t *vm_flags_ptr);
+               unsigned long start, unsigned long end, vma_flags_t *vma_flags_ptr);
 
 /**
  * vma_modify_name() - Perform any necessary split/merge in preparation for
@@ -418,7 +417,7 @@ __must_check struct vm_area_struct *vma_modify_policy(struct vma_iterator *vmi,
  * @vma: The VMA containing the range @start to @end to be updated.
  * @start: The start of the range to update. May be offset within @vma.
  * @end: The exclusive end of the range to update, may be offset within @vma.
- * @vm_flags: The VMA flags that the @start to @end range is about to be set to.
+ * @vma_flags: The VMA flags that the @start to @end range is about to be set to.
  * @new_ctx: The userfaultfd context that the @start to @end range is about to
  * be set to.
  * @give_up_on_oom: If an out of memory condition occurs on merge, simply give
@@ -429,11 +428,11 @@ __must_check struct vm_area_struct *vma_modify_policy(struct vma_iterator *vmi,
  * and the caller is expected to perform the actual modification.
  *
  * Returns: A VMA which contains the range @start to @end ready to have its VMA
- * flags changed to @vm_flags and its userfaultfd context changed to @new_ctx.
+ * flags changed to @vma_flags and its userfaultfd context changed to @new_ctx.
  */
 __must_check struct vm_area_struct *vma_modify_flags_uffd(struct vma_iterator *vmi,
                struct vm_area_struct *prev, struct vm_area_struct *vma,
-               unsigned long start, unsigned long end, vm_flags_t vm_flags,
+               unsigned long start, unsigned long end, const vma_flags_t *vma_flags,
                struct vm_userfaultfd_ctx new_ctx, bool give_up_on_oom);
 
 __must_check struct vm_area_struct *vma_merge_new_range(struct vma_merge_struct *vmg);
diff --git a/tools/testing/vma/tests/merge.c b/tools/testing/vma/tests/merge.c
index 44e3977e3fc0..03b6f9820e0a 100644
--- a/tools/testing/vma/tests/merge.c
+++ b/tools/testing/vma/tests/merge.c
@@ -132,7 +132,6 @@ static bool test_simple_modify(void)
        struct vm_area_struct *vma;
        vma_flags_t vma_flags = mk_vma_flags(VMA_READ_BIT, VMA_WRITE_BIT, VMA_MAYREAD_BIT,
                                             VMA_MAYWRITE_BIT);
-       vm_flags_t legacy_flags = VM_READ | VM_WRITE;
        struct mm_struct mm = {};
        struct vm_area_struct *init_vma = alloc_vma(&mm, 0, 0x3000, 0, vma_flags);
        VMA_ITERATOR(vmi, &mm, 0x1000);
@@ -144,7 +143,7 @@ static bool test_simple_modify(void)
         * performs the merge/split only.
         */
        vma = vma_modify_flags(&vmi, init_vma, init_vma,
-                              0x1000, 0x2000, &legacy_flags);
+                              0x1000, 0x2000, &vma_flags);
        ASSERT_NE(vma, NULL);
        /* We modify the provided VMA, and on split allocate new VMAs. */
        ASSERT_EQ(vma, init_vma);
-- 
2.53.0


Reply via email to