On Fri, Sep 15, 2023 at 10:59:26AM +0000, Matteo Rizzo wrote:
> From: Jann Horn <ja...@google.com>
> 
> This is refactoring in preparation for checking freeptrs for corruption
> inside freelist_ptr_decode().
> 
> Signed-off-by: Jann Horn <ja...@google.com>
> Co-developed-by: Matteo Rizzo <matteori...@google.com>
> Signed-off-by: Matteo Rizzo <matteori...@google.com>
> ---
>  mm/slub.c | 43 +++++++++++++++++++++++--------------------
>  1 file changed, 23 insertions(+), 20 deletions(-)
> 
> diff --git a/mm/slub.c b/mm/slub.c
> index eaa1256aff89..42e7cc0b4452 100644
> --- a/mm/slub.c
> +++ b/mm/slub.c
> @@ -383,7 +383,8 @@ static inline freeptr_t freelist_ptr_encode(const struct 
> kmem_cache *s,
>  }
>  
>  static inline void *freelist_ptr_decode(const struct kmem_cache *s,
> -                                     freeptr_t ptr, unsigned long ptr_addr)
> +                                     freeptr_t ptr, unsigned long ptr_addr,
> +                                     struct slab *slab)

Can't "s" be dropped in these cases, since "s" should be
"slab->kmem_cache"?

>  {
>       void *decoded;
>  
> @@ -395,7 +396,8 @@ static inline void *freelist_ptr_decode(const struct 
> kmem_cache *s,
>       return decoded;
>  }
>  
> -static inline void *get_freepointer(struct kmem_cache *s, void *object)
> +static inline void *get_freepointer(struct kmem_cache *s, void *object,
> +                                 struct slab *slab)
>  {
>       unsigned long ptr_addr;
>       freeptr_t p;
> @@ -403,7 +405,7 @@ static inline void *get_freepointer(struct kmem_cache *s, 
> void *object)
>       object = kasan_reset_tag(object);
>       ptr_addr = (unsigned long)object + s->offset;
>       p = *(freeptr_t *)(ptr_addr);
> -     return freelist_ptr_decode(s, p, ptr_addr);
> +     return freelist_ptr_decode(s, p, ptr_addr, slab);
>  }
>  
>  #ifndef CONFIG_SLUB_TINY
> @@ -424,18 +426,19 @@ static void prefetch_freepointer(const struct 
> kmem_cache *s, void *object)
>   * get_freepointer_safe() returns initialized memory.
>   */
>  __no_kmsan_checks
> -static inline void *get_freepointer_safe(struct kmem_cache *s, void *object)
> +static inline void *get_freepointer_safe(struct kmem_cache *s, void *object,
> +                                          struct slab *slab)
>  {
>       unsigned long freepointer_addr;
>       freeptr_t p;
>  
>       if (!debug_pagealloc_enabled_static())
> -             return get_freepointer(s, object);
> +             return get_freepointer(s, object, slab);
>  
>       object = kasan_reset_tag(object);
>       freepointer_addr = (unsigned long)object + s->offset;
>       copy_from_kernel_nofault(&p, (freeptr_t *)freepointer_addr, sizeof(p));
> -     return freelist_ptr_decode(s, p, freepointer_addr);
> +     return freelist_ptr_decode(s, p, freepointer_addr, slab);
>  }
>  
>  static inline void set_freepointer(struct kmem_cache *s, void *object, void 
> *fp)
> @@ -627,7 +630,7 @@ static void __fill_map(unsigned long *obj_map, struct 
> kmem_cache *s,
>  
>       bitmap_zero(obj_map, slab->objects);
>  
> -     for (p = slab->freelist; p; p = get_freepointer(s, p))
> +     for (p = slab->freelist; p; p = get_freepointer(s, p, slab))
>               set_bit(__obj_to_index(s, addr, p), obj_map);
>  }
>  
> @@ -937,7 +940,7 @@ static void print_trailer(struct kmem_cache *s, struct 
> slab *slab, u8 *p)
>       print_slab_info(slab);
>  
>       pr_err("Object 0x%p @offset=%tu fp=0x%p\n\n",
> -            p, p - addr, get_freepointer(s, p));
> +            p, p - addr, get_freepointer(s, p, slab));
>  
>       if (s->flags & SLAB_RED_ZONE)
>               print_section(KERN_ERR, "Redzone  ", p - s->red_left_pad,
> @@ -1230,7 +1233,7 @@ static int check_object(struct kmem_cache *s, struct 
> slab *slab,
>               return 1;
>  
>       /* Check free pointer validity */
> -     if (!check_valid_pointer(s, slab, get_freepointer(s, p))) {
> +     if (!check_valid_pointer(s, slab, get_freepointer(s, p, slab))) {
>               object_err(s, slab, p, "Freepointer corrupt");
>               /*
>                * No choice but to zap it and thus lose the remainder
> @@ -1298,7 +1301,7 @@ static int on_freelist(struct kmem_cache *s, struct 
> slab *slab, void *search)
>                       break;
>               }
>               object = fp;
> -             fp = get_freepointer(s, object);
> +             fp = get_freepointer(s, object, slab);
>               nr++;
>       }
>  
> @@ -1810,7 +1813,7 @@ static inline bool slab_free_freelist_hook(struct 
> kmem_cache *s,
>               object = next;
>               /* Single objects don't actually contain a freepointer */
>               if (object != old_tail)
> -                     next = get_freepointer(s, object);
> +                     next = get_freepointer(s, object, virt_to_slab(object));
>  
>               /* If object's reuse doesn't have to be delayed */
>               if (!slab_free_hook(s, object, slab_want_init_on_free(s))) {
> @@ -2161,7 +2164,7 @@ static void *alloc_single_from_partial(struct 
> kmem_cache *s,
>       lockdep_assert_held(&n->list_lock);
>  
>       object = slab->freelist;
> -     slab->freelist = get_freepointer(s, object);
> +     slab->freelist = get_freepointer(s, object, slab);
>       slab->inuse++;
>  
>       if (!alloc_debug_processing(s, slab, object, orig_size)) {
> @@ -2192,7 +2195,7 @@ static void *alloc_single_from_new_slab(struct 
> kmem_cache *s,
>  
>  
>       object = slab->freelist;
> -     slab->freelist = get_freepointer(s, object);
> +     slab->freelist = get_freepointer(s, object, slab);
>       slab->inuse = 1;
>  
>       if (!alloc_debug_processing(s, slab, object, orig_size))
> @@ -2517,7 +2520,7 @@ static void deactivate_slab(struct kmem_cache *s, 
> struct slab *slab,
>       freelist_tail = NULL;
>       freelist_iter = freelist;
>       while (freelist_iter) {
> -             nextfree = get_freepointer(s, freelist_iter);
> +             nextfree = get_freepointer(s, freelist_iter, slab);
>  
>               /*
>                * If 'nextfree' is invalid, it is possible that the object at
> @@ -2944,7 +2947,7 @@ static inline bool free_debug_processing(struct 
> kmem_cache *s,
>  
>       /* Reached end of constructed freelist yet? */
>       if (object != tail) {
> -             object = get_freepointer(s, object);
> +             object = get_freepointer(s, object, slab);
>               goto next_object;
>       }
>       checks_ok = true;
> @@ -3173,7 +3176,7 @@ static void *___slab_alloc(struct kmem_cache *s, gfp_t 
> gfpflags, int node,
>        * That slab must be frozen for per cpu allocations to work.
>        */
>       VM_BUG_ON(!c->slab->frozen);
> -     c->freelist = get_freepointer(s, freelist);
> +     c->freelist = get_freepointer(s, freelist, c->slab);
>       c->tid = next_tid(c->tid);
>       local_unlock_irqrestore(&s->cpu_slab->lock, flags);
>       return freelist;
> @@ -3275,7 +3278,7 @@ static void *___slab_alloc(struct kmem_cache *s, gfp_t 
> gfpflags, int node,
>                * For !pfmemalloc_match() case we don't load freelist so that
>                * we don't make further mismatched allocations easier.
>                */
> -             deactivate_slab(s, slab, get_freepointer(s, freelist));
> +             deactivate_slab(s, slab, get_freepointer(s, freelist, slab));
>               return freelist;
>       }
>  
> @@ -3377,7 +3380,7 @@ static __always_inline void *__slab_alloc_node(struct 
> kmem_cache *s,
>           unlikely(!object || !slab || !node_match(slab, node))) {
>               object = __slab_alloc(s, gfpflags, node, addr, c, orig_size);
>       } else {
> -             void *next_object = get_freepointer_safe(s, object);
> +             void *next_object = get_freepointer_safe(s, object, slab);
>  
>               /*
>                * The cmpxchg will only match if there was no additional
> @@ -3984,7 +3987,7 @@ static inline int __kmem_cache_alloc_bulk(struct 
> kmem_cache *s, gfp_t flags,
>  
>                       continue; /* goto for-loop */
>               }
> -             c->freelist = get_freepointer(s, object);
> +             c->freelist = get_freepointer(s, object, c->slab);
>               p[i] = object;
>               maybe_wipe_obj_freeptr(s, p[i]);
>       }
> @@ -4275,7 +4278,7 @@ static void early_kmem_cache_node_alloc(int node)
>       init_tracking(kmem_cache_node, n);
>  #endif
>       n = kasan_slab_alloc(kmem_cache_node, n, GFP_KERNEL, false);
> -     slab->freelist = get_freepointer(kmem_cache_node, n);
> +     slab->freelist = get_freepointer(kmem_cache_node, n, slab);
>       slab->inuse = 1;
>       kmem_cache_node->node[node] = n;
>       init_kmem_cache_node(n);
> -- 
> 2.42.0.459.ge4e396fd5e-goog
> 

-- 
Kees Cook

Reply via email to