On Sun, Jan 31, 2016 at 01:32:47PM +1100, Matthew Wilcox wrote:
> On Fri, Jan 29, 2016 at 10:01:13PM -0800, Dan Williams wrote:
> > On Fri, Jan 29, 2016 at 9:28 PM, Matthew Wilcox <wi...@linux.intel.com> wrote:
> > > If we store the PFN of the underlying page instead, we don't have this
> > > problem.  Instead, we have a different problem; of the device going
> > > away under us.  I'm trying to find the code which tears down PTEs when
> > > the device goes away, and I'm not seeing it.  What do we do about user
> > > mappings of the device?
> > 
> > I deferred the dax tear down code until next cycle as Al rightly
> > pointed out some needed re-works:
> > 
> > https://lists.01.org/pipermail/linux-nvdimm/2016-January/003995.html
> 
> Thanks; I eventually found it in my email somewhere over the Pacific.
> 
> I did probably 70% of the work needed to switch the radix tree over to
> storing PFNs instead of sectors.  It seems viable, though it's a big
> change from where we are today:

70%?!  Hah.  I'd done maybe 50%.  This isn't everything needed; I still
need to write radix_tree_replace().  But it's enough to get a flavour for
where this line of thinking takes us.  I think it ends up being cleaner
code, and possibly better performing.  I also think it points us back
in the direction of wanting an address_space operation to return a PFN
for the radix tree instead of handling buffer_heads directly in dax.c.

Ah well.  Time to sleep ...

>From 0321c30eeb189ad2da8dcc25623419e2ba9c6cee Mon Sep 17 00:00:00 2001
From: Matthew Wilcox <matthew.r.wil...@intel.com>
Date: Sun, 31 Jan 2016 13:38:21 +1100
Subject: [PATCH] Giant non-compiling mess

Note that clear_pmem() needs to be updated to set the current->needs_wmb flag.

Signed-off-by: Matthew Wilcox <matthew.r.wil...@intel.com>
---
 fs/dax.c                   | 1127 +++++++++++++++++++++-----------------------
 include/linux/dax.h        |    3 +-
 include/linux/pfn_t.h      |   41 +-
 include/linux/radix-tree.h |    9 -
 include/linux/sched.h      |    1 +
 5 files changed, 565 insertions(+), 616 deletions(-)

diff --git a/fs/dax.c b/fs/dax.c
index e9701d6..38b92b5 100644
--- a/fs/dax.c
+++ b/fs/dax.c
@@ -25,12 +25,132 @@
 #include <linux/mm.h>
 #include <linux/mutex.h>
 #include <linux/pagevec.h>
+#include <linux/pfn_t.h>
 #include <linux/pmem.h>
+#include <linux/preempt.h>
 #include <linux/sched.h>
+#include <linux/sizes.h>
 #include <linux/uio.h>
 #include <linux/vmstat.h>
-#include <linux/pfn_t.h>
-#include <linux/sizes.h>
+
+/*
+ * 32-bit architectures want to override this to actually map/unmap
+ * their persistent memory.  ARM, SPARC & MIPS also want to override it
+ * to map the PFN at an address that uses the same cachelines as the
+ * userspace mapping (that's what 'index' is for)
+ */
+static void *dax_map_pfn(pfn_t pfn, unsigned long index)
+{
+       preempt_disable();
+       pagefault_disable();
+       if (is_bad_pfn_t(pfn))
+               return NULL;    /* still paired with dax_unmap_pfn() */
+       return pfn_to_kaddr(pfn_t_to_pfn(pfn));
+}
+
+static void dax_unmap_pfn(void *addr)
+{
+       pagefault_enable();
+       preempt_enable();
+}
+
+/*
+ * DAX uses the 'exceptional' entries to store PFNs in the radix tree.
+ * Bit 0 is clear (the radix tree uses this for its own purposes).  Bit
+ * 1 is set (to indicate an exceptional entry).  Bits 2 & 3 are PFN_DEV
+ * and PFN_MAP.  The top two bits denote the size of the entry (PTE, PMD,
+ * PUD, one reserved).  That leaves us 26 bits on 32-bit systems and 58
+ * bits on 64-bit systems, able to address 256GB and 1024EB respectively.
+ */
+#define RADIX_DAX_SIZE_MASK    (0x3UL << (BITS_PER_LONG - 2))
+#define RADIX_TREE_MASK                (RADIX_TREE_INDIRECT_PTR | \
+                                RADIX_TREE_EXCEPTIONAL_ENTRY)
+#define RADIX_DAX_PFN_MASK     (~(RADIX_DAX_SIZE_MASK | RADIX_TREE_MASK))
+#define RADIX_DAX_SHIFT                4
+#define RADIX_DAX_PTE          (0x0UL << (BITS_PER_LONG - 2))
+#define RADIX_DAX_PMD          (0x1UL << (BITS_PER_LONG - 2))
+#define RADIX_DAX_PUD          (0x2UL << (BITS_PER_LONG - 2))
+#define RADIX_DAX_SIZE(entry)  ((unsigned long)entry & RADIX_DAX_SIZE_MASK)
+#define RADIX_DAX_ENTRY(pfn, size) \
+       ((void *)((pfn_t_to_pfn(pfn) << RADIX_DAX_SHIFT) | size))
+
+/* The 'colour' (ie low bits) within a PMD/PUD of a page offset. */
+#define PG_PMD_COLOUR  ((PMD_SIZE >> PAGE_CACHE_SHIFT) - 1)
+#define PG_PUD_COLOUR  ((PUD_SIZE >> PAGE_CACHE_SHIFT) - 1)
+
+static pfn_t radix_to_pfn_t(void *entry, pgoff_t index)
+{
+       pfn_t pfn = { .val = (unsigned long)entry & RADIX_DAX_PFN_MASK };
+       unsigned offset = 0;
+
+       if (RADIX_DAX_SIZE(entry) == RADIX_DAX_PMD)
+               offset = index & PG_PMD_COLOUR;
+       else if (RADIX_DAX_SIZE(entry) == RADIX_DAX_PUD)
+               offset = index & PG_PUD_COLOUR;
+
+       return pfn_t_add(pfn, offset);
+}
+
+static void *pfn_to_radix(pfn_t pfn, unsigned long size)
+{
+       unsigned long value = pfn.val;
+       BUG_ON(value & ~RADIX_DAX_PFN_MASK);    /* must fit the PFN field */
+       return (void *)(value | size | RADIX_TREE_EXCEPTIONAL_ENTRY);
+}
+
+static unsigned size_to_order(unsigned long size)
+{
+       switch (size) {
+       case RADIX_DAX_PTE: return 0;
+       case RADIX_DAX_PMD: return PMD_SHIFT - PAGE_SHIFT;
+       case RADIX_DAX_PUD: return PUD_SHIFT - PAGE_SHIFT;
+       }
+       BUG();
+}
+
+static unsigned size_to_bytes(unsigned long size)
+{
+       switch (size) {
+       case RADIX_DAX_PTE: return PAGE_CACHE_SIZE;
+       case RADIX_DAX_PMD: return PMD_SIZE;
+       case RADIX_DAX_PUD: return PUD_SIZE;
+       }
+       BUG();
+}
+
+static int dax_add_radix_entry(struct address_space *mapping, pgoff_t index,
+                               pfn_t pfn, unsigned long size, bool dirty)
+{
+       struct radix_tree_root *page_tree = &mapping->page_tree;
+       int count = 0;
+       void *entry;
+       unsigned order = size_to_order(size);
+
+       if (dirty)
+               __mark_inode_dirty(mapping->host, I_DIRTY_PAGES);
+
+       spin_lock_irq(&mapping->tree_lock);
+       entry = radix_tree_lookup(page_tree, index);
+       if (entry && !radix_tree_exceptional_entry(entry)) {   /* a page */
+               count = -EEXIST;
+               goto unlock;
+       } else if (entry) {
+               if (size <= RADIX_DAX_SIZE(entry))
+                       goto dirty;
+       }
+       count = radix_tree_replace(page_tree, index, order,
+                                       pfn_to_radix(pfn, size));
+       if (count < 0)
+               goto unlock;
+
+       mapping->nrexceptional -= (count - 1);
+ dirty:
+       if (dirty)
+               radix_tree_tag_set(page_tree, index, PAGECACHE_TAG_DIRTY);
+ unlock:
+       spin_unlock_irq(&mapping->tree_lock);
+       return count;
+}
 
 static long dax_map_atomic(struct block_device *bdev, struct blk_dax_ctl *dax)
 {
@@ -58,17 +178,235 @@ static void dax_unmap_atomic(struct block_device *bdev,
        blk_queue_exit(bdev->bd_queue);
 }
 
+static sector_t to_sector(const struct buffer_head *bh,
+               const struct inode *inode)
+{
+       sector_t sector = bh->b_blocknr << (inode->i_blkbits - 9);
+
+       return sector;
+}
+
+static bool buffer_written(struct buffer_head *bh)
+{
+       return buffer_mapped(bh) && !buffer_unwritten(bh);
+}
+
+static int dax_replace_hole(struct address_space *mapping, pgoff_t index,
+                               unsigned long size, pfn_t pfn)
+{
+       unsigned order = size_to_order(size);
+       int i, error;
+
+       for (i = 0; i < (1 << order); i++) {
+               struct page *page;
+ repeat:
+               page = find_get_entry(mapping, index + i);
+               if (!page || radix_tree_exceptional_entry(page))
+                       continue;
+
+               lock_page(page);
+               if (unlikely(page->mapping != mapping)) {
+                       unlock_page(page);
+                       page_cache_release(page);
+                       goto repeat;
+               }
+
+               delete_from_page_cache(page);
+               unlock_page(page);
+               page_cache_release(page);
+       }
+
+       /*
+        * Somebody else could look in the radix tree and find nothing.
+        * It's harmless; they'll find the pfn by calling the filesystem.
+        * Unmap every page the new entry covers, not just the first one.
+        */
+       error = dax_add_radix_entry(mapping, index, pfn, size, true);
+
+       unmap_mapping_range(mapping, (loff_t)index << PAGE_CACHE_SHIFT,
+                               size_to_bytes(size), 0);
+
+       return error;
+}
+
+static int dax_add_pfn_sized(struct address_space *mapping, pgoff_t index,
+                               size_t size, bool write, pfn_t pfn,
+                               unsigned long radix_size, unsigned entry_size)
+{
+       int error = 0;  /* returned as-is if @size < @entry_size */
+       bool report = true;
+
+       while (size >= entry_size) {
+               error = dax_add_radix_entry(mapping, index, pfn,
+                                               radix_size, write);
+               if (error == -EEXIST)
+                       error = dax_replace_hole(mapping, index, radix_size,
+                                               pfn);
+               if (error < 0)  /* a count >= 0 means success */
+                       break;
+               report = false;
+
+               size -= entry_size;
+               pfn = pfn_t_add(pfn, entry_size / PAGE_CACHE_SIZE);
+               index += entry_size / PAGE_CACHE_SIZE;
+       }
+
+       return report ? error : 0;
+}
+
+static int dax_add_pfn_entries(struct address_space *mapping, pgoff_t index,
+                               size_t size, bool write, pfn_t pfn)
+{
+       unsigned long radix_size;
+       size_t bytes;
+       bool added = false;
+       int error;
+
+       /*
+        * Cover the extent greedily with the largest entry that both the
+        * alignment of @index and the remaining @size allow: PTEs up to
+        * the next PMD boundary, PMDs up to the next PUD boundary, then
+        * PUDs, then PMDs and finally PTEs for the tail.
+        *
+        * @pfn and @index must advance in lockstep so that each entry
+        * describes its own part of the extent, not the start of it.
+        * A trailing partial page (less than PAGE_CACHE_SIZE) is not
+        * added to the radix tree.
+        */
+       while (size >= PAGE_CACHE_SIZE) {
+               if (!(index & PG_PUD_COLOUR) && size >= PUD_SIZE) {
+                       radix_size = RADIX_DAX_PUD;
+                       bytes = PUD_SIZE;
+               } else if (!(index & PG_PMD_COLOUR) && size >= PMD_SIZE) {
+                       radix_size = RADIX_DAX_PMD;
+                       bytes = PMD_SIZE;
+               } else {
+                       radix_size = RADIX_DAX_PTE;
+                       bytes = PAGE_CACHE_SIZE;
+               }
+
+               /* size == entry_size, so this adds exactly one entry */
+               error = dax_add_pfn_sized(mapping, index, bytes, write, pfn,
+                                               radix_size, bytes);
+               if (error < 0) {
+                       /*
+                        * Only report failure if nothing was added; once
+                        * the entry for the caller's @index is in the tree
+                        * the rest can be filled in by a later lookup miss.
+                        */
+                       return added ? 0 : error;
+               }
+               added = true;
+
+               size -= bytes;
+               pfn = pfn_t_add(pfn, bytes / PAGE_CACHE_SIZE);
+               index += bytes / PAGE_CACHE_SIZE;
+       }
+
+       return 0;
+}
+
+/*
+ * Populate the page cache with as many pfns as the filesystem is willing
+ * to tell us about from a single call to get_block, starting at @index and
+ * continuing up to @max bytes.
+ */
+static int dax_create_pfns(struct address_space *mapping, pgoff_t index,
+                               unsigned max, bool write, pfn_t *pfn,
+                               get_block_t get_block, struct buffer_head *bh)
+{
+       struct inode *inode = mapping->host;
+       unsigned blkbits = inode->i_blkbits;
+       sector_t block = index << (PAGE_CACHE_SHIFT - blkbits);
+       struct blk_dax_ctl dax;
+       int error, result = 0;
+
+       bh->b_size = max;
+       bh->b_state = 0;
+       error = get_block(inode, block, bh, write);
+       if (error)
+               goto error;
+
+       if (!buffer_written(bh))
+               goto hole;
+
+       dax.sector = to_sector(bh, inode);
+       dax.size = bh->b_size;
+       error = dax_map_atomic(bh->b_bdev, &dax);
+       if (error < 0)
+               goto error;
+
+       /*
+        * We may be about to write data to it, but now it's allocated,
+        * and another thread will be able to find it in the page cache,
+        * so we have to zero it otherwise there's a write vs fault race
+        * that could expose stale data to an application.
+        */
+       if (buffer_unwritten(bh) || buffer_new(bh)) {
+               clear_pmem(dax.addr, bh->b_size);
+               result = 1;
+       }
+
+       dax_unmap_atomic(bh->b_bdev, &dax);
+
+       error = dax_add_pfn_entries(mapping, index, bh->b_size,
+                                       write, dax.pfn);
+
+       /*
+        * Even if we had an error adding the PFN to the radix tree,
+        * the PFN is still good, so return it.
+        */
+       *pfn = dax.pfn;
+       return error ? error : result;
+
+ hole:
+ error:
+       *pfn = bad_pfn_t;
+       return error;
+}
+
+/*
+ * Returns either a negative errno, 0 if no allocation had to be performed,
+ * or 1 if the filesystem allocated a block.
+ */
+static int dax_get_pfn(struct address_space *mapping, pgoff_t index,
+                               size_t len, bool write, pfn_t *pfn,
+                               get_block_t get_block, struct buffer_head *bh)
+{
+       void *entry;
+
+       rcu_read_lock();
+       entry = radix_tree_lookup(&mapping->page_tree, index);
+       rcu_read_unlock();
+
+       if (radix_tree_exceptional_entry(entry)) {
+               *pfn = radix_to_pfn_t(entry, index);
+               return 0;
+       }
+
+       if (entry) {
+               if (write)
+                       return dax_create_pfns(mapping, index, len, true, pfn,
+                                               get_block, bh);
+       } else {
+               return dax_create_pfns(mapping, index, len, write, pfn,
+                                               get_block, bh);
+       }
+
+       *pfn = bad_pfn_t;
+       return 0;
+}
+
 /*
  * dax_clear_blocks() is called from within transaction context from XFS,
  * and hence this means the stack from this point must follow GFP_NOFS
  * semantics for all operations.
  */
-int dax_clear_blocks(struct inode *inode, sector_t block, long _size)
+int dax_clear_blocks(struct block_device *bdev, sector_t sector, long size)
 {
-       struct block_device *bdev = inode->i_sb->s_bdev;
        struct blk_dax_ctl dax = {
-               .sector = block << (inode->i_blkbits - 9),
-               .size = _size,
+               .sector = sector,
+               .size = size,
        };
 
        might_sleep();
@@ -91,133 +429,52 @@ int dax_clear_blocks(struct inode *inode, sector_t block, long _size)
 }
 EXPORT_SYMBOL_GPL(dax_clear_blocks);
 
-/* the clear_pmem() calls are ordered by a wmb_pmem() in the caller */
-static void dax_new_buf(void __pmem *addr, unsigned size, unsigned first,
-               loff_t pos, loff_t end)
-{
-       loff_t final = end - pos + first; /* The final byte of the buffer */
-
-       if (first > 0)
-               clear_pmem(addr, first);
-       if (final < size)
-               clear_pmem(addr + final, size - final);
-}
-
-static bool buffer_written(struct buffer_head *bh)
-{
-       return buffer_mapped(bh) && !buffer_unwritten(bh);
-}
-
-/*
- * When ext4 encounters a hole, it returns without modifying the buffer_head
- * which means that we can't trust b_size.  To cope with this, we set b_state
- * to 0 before calling get_block and, if any bit is set, we know we can trust
- * b_size.  Unfortunate, really, since ext4 knows precisely how long a hole is
- * and would save us time calling get_block repeatedly.
- */
-static bool buffer_size_valid(struct buffer_head *bh)
-{
-       return bh->b_state != 0;
-}
-
-
-static sector_t to_sector(const struct buffer_head *bh,
-               const struct inode *inode)
-{
-       sector_t sector = bh->b_blocknr << (inode->i_blkbits - 9);
-
-       return sector;
-}
-
 static ssize_t dax_io(struct inode *inode, struct iov_iter *iter,
-                     loff_t start, loff_t end, get_block_t get_block,
-                     struct buffer_head *bh)
+                               loff_t start, loff_t end,
+                               get_block_t get_block, struct buffer_head *bh)
 {
-       loff_t pos = start, max = start, bh_max = start;
-       bool hole = false, need_wmb = false;
-       struct block_device *bdev = NULL;
-       int rw = iov_iter_rw(iter), rc;
-       long map_len = 0;
-       struct blk_dax_ctl dax = {
-               .addr = (void __pmem *) ERR_PTR(-EIO),
-       };
+       loff_t pos = start;
+       int error = 0;
+       const int rw = iov_iter_rw(iter);
 
        if (rw == READ)
                end = min(end, i_size_read(inode));
 
-       while (pos < end) {
-               size_t len;
-               if (pos == max) {
-                       unsigned blkbits = inode->i_blkbits;
-                       long page = pos >> PAGE_SHIFT;
-                       sector_t block = page << (PAGE_SHIFT - blkbits);
-                       unsigned first = pos - (block << blkbits);
-                       long size;
-
-                       if (pos == bh_max) {
-                               bh->b_size = PAGE_ALIGN(end - pos);
-                               bh->b_state = 0;
-                               rc = get_block(inode, block, bh, rw == WRITE);
-                               if (rc)
-                                       break;
-                               if (!buffer_size_valid(bh))
-                                       bh->b_size = 1 << blkbits;
-                               bh_max = pos - first + bh->b_size;
-                               bdev = bh->b_bdev;
-                       } else {
-                               unsigned done = bh->b_size -
-                                               (bh_max - (pos - first));
-                               bh->b_blocknr += done >> blkbits;
-                               bh->b_size -= done;
-                       }
+       while (error >= 0 && pos < end) {
+               pgoff_t pgoff = pos >> PAGE_CACHE_SHIFT;
+               unsigned off = pos & ~PAGE_CACHE_MASK;
+               size_t len = end - pos;
+               pfn_t pfn;
+               void __pmem *kaddr, *addr;
 
-                       hole = rw == READ && !buffer_written(bh);
-                       if (hole) {
-                               size = bh->b_size - first;
-                       } else {
-                               dax_unmap_atomic(bdev, &dax);
-                               dax.sector = to_sector(bh, inode);
-                               dax.size = bh->b_size;
-                               map_len = dax_map_atomic(bdev, &dax);
-                               if (map_len < 0) {
-                                       rc = map_len;
-                                       break;
-                               }
-                               if (buffer_unwritten(bh) || buffer_new(bh)) {
-                                       dax_new_buf(dax.addr, map_len, first,
-                                                       pos, end);
-                                       need_wmb = true;
-                               }
-                               dax.addr += first;
-                               size = map_len - first;
-                       }
-                       max = min(pos + size, end);
-               }
+               error = dax_get_pfn(inode->i_mapping, pgoff, len, rw == WRITE,
+                                               &pfn, get_block, bh);
+               if (error < 0)
+                       break;
+               kaddr = dax_map_pfn(pfn, pgoff);
 
-               if (iov_iter_rw(iter) == WRITE) {
-                       len = copy_from_iter_pmem(dax.addr, max - pos, iter);
-                       need_wmb = true;
-               } else if (!hole)
-                       len = copy_to_iter((void __force *) dax.addr, max - pos,
-                                       iter);
+               if (len > PAGE_CACHE_SIZE - off)
+                       len = PAGE_CACHE_SIZE - off;
+               addr = kaddr ? kaddr + off : NULL;      /* NULL: hole */
+               if (rw == WRITE) {
+                       len = copy_from_iter_pmem(addr, len, iter);
+                       current->needs_wmb = true;
+               } else if (addr)
+                       len = copy_to_iter((void __force *) addr, len, iter);
                else
-                       len = iov_iter_zero(max - pos, iter);
+                       len = iov_iter_zero(len, iter);
+               dax_unmap_pfn(kaddr);
 
-               if (!len) {
-                       rc = -EFAULT;
-                       break;
-               }
+               if (!len)
+                       error = -EFAULT;
 
                pos += len;
-               if (!IS_ERR(dax.addr))
-                       dax.addr += len;
        }
 
-       if (need_wmb)
+       if (current->needs_wmb)
                wmb_pmem();
-       dax_unmap_atomic(bdev, &dax);
 
-       return (pos == start) ? rc : pos - start;
+       return (pos == start) ? error : pos - start;
 }
 
 /**
@@ -238,15 +495,14 @@ static ssize_t dax_io(struct inode *inode, struct iov_iter *iter,
  * is in progress.
  */
 ssize_t dax_do_io(struct kiocb *iocb, struct inode *inode,
-                 struct iov_iter *iter, loff_t pos, get_block_t get_block,
-                 dio_iodone_t end_io, int flags)
+                       struct iov_iter *iter, loff_t pos,
+                       get_block_t get_block, dio_iodone_t end_io, int flags)
 {
        struct buffer_head bh;
        ssize_t retval = -EINVAL;
        loff_t end = pos + iov_iter_count(iter);
 
        memset(&bh, 0, sizeof(bh));
-       bh.b_bdev = inode->i_sb->s_bdev;
 
        if ((flags & DIO_LOCKING) && iov_iter_rw(iter) == READ) {
                struct address_space *mapping = inode->i_mapping;
@@ -277,124 +533,26 @@ ssize_t dax_do_io(struct kiocb *iocb, struct inode *inode,
 }
 EXPORT_SYMBOL_GPL(dax_do_io);
 
-/*
- * The user has performed a load from a hole in the file.  Allocating
- * a new page in the file would cause excessive storage usage for
- * workloads with sparse files.  We allocate a page cache page instead.
- * We'll kick it out of the page cache if it's ever written to,
- * otherwise it will simply fall out of the page cache under memory
- * pressure without ever having been dirtied.
- */
-static int dax_load_hole(struct address_space *mapping, struct page *page,
-                                                       struct vm_fault *vmf)
+static int copy_user_pfn(struct vm_fault *vmf, pfn_t pfn)
 {
-       if (!page)
-               page = find_or_create_page(mapping, vmf->pgoff,
-                                               vmf->gfp_mask | __GFP_ZERO);
-       if (!page)
-               return VM_FAULT_OOM;
-       vmf->page = page;
-       return VM_FAULT_LOCKED;
-}
+       void *vto, *vfrom;
 
-static int copy_user_bh(struct page *to, struct inode *inode,
-               struct buffer_head *bh, unsigned long vaddr)
-{
-       struct blk_dax_ctl dax = {
-               .sector = to_sector(bh, inode),
-               .size = bh->b_size,
-       };
-       struct block_device *bdev = bh->b_bdev;
-       void *vto;
-
-       if (dax_map_atomic(bdev, &dax) < 0)
-               return PTR_ERR(dax.addr);
-       vto = kmap_atomic(to);
-       copy_user_page(vto, (void __force *)dax.addr, vaddr, to);
+       vfrom = dax_map_pfn(pfn, vmf->pgoff);
+       vto = kmap_atomic(vmf->cow_page);
+       copy_user_page(vto, vfrom, (unsigned long)vmf->virtual_address,
+                       vmf->cow_page);
        kunmap_atomic(vto);
-       dax_unmap_atomic(bdev, &dax);
+       dax_unmap_pfn(vfrom);
        return 0;
 }
 
-#define NO_SECTOR -1
-#define DAX_PMD_INDEX(page_index) (page_index & (PMD_MASK >> PAGE_CACHE_SHIFT))
-
-static int dax_radix_entry(struct address_space *mapping, pgoff_t index,
-               sector_t sector, bool pmd_entry, bool dirty)
-{
-       struct radix_tree_root *page_tree = &mapping->page_tree;
-       pgoff_t pmd_index = DAX_PMD_INDEX(index);
-       int type, error = 0;
-       void *entry;
-
-       WARN_ON_ONCE(pmd_entry && !dirty);
-       __mark_inode_dirty(mapping->host, I_DIRTY_PAGES);
-
-       spin_lock_irq(&mapping->tree_lock);
-
-       entry = radix_tree_lookup(page_tree, pmd_index);
-       if (entry && RADIX_DAX_TYPE(entry) == RADIX_DAX_PMD) {
-               index = pmd_index;
-               goto dirty;
-       }
-
-       entry = radix_tree_lookup(page_tree, index);
-       if (entry) {
-               type = RADIX_DAX_TYPE(entry);
-               if (WARN_ON_ONCE(type != RADIX_DAX_PTE &&
-                                       type != RADIX_DAX_PMD)) {
-                       error = -EIO;
-                       goto unlock;
-               }
-
-               if (!pmd_entry || type == RADIX_DAX_PMD)
-                       goto dirty;
-
-               /*
-                * We only insert dirty PMD entries into the radix tree.  This
-                * means we don't need to worry about removing a dirty PTE
-                * entry and inserting a clean PMD entry, thus reducing the
-                * range we would flush with a follow-up fsync/msync call.
-                */
-               radix_tree_delete(&mapping->page_tree, index);
-               mapping->nrexceptional--;
-       }
-
-       if (sector == NO_SECTOR) {
-               /*
-                * This can happen during correct operation if our pfn_mkwrite
-                * fault raced against a hole punch operation.  If this
-                * happens the pte that was hole punched will have been
-                * unmapped and the radix tree entry will have been removed by
-                * the time we are called, but the call will still happen.  We
-                * will return all the way up to wp_pfn_shared(), where the
-                * pte_same() check will fail, eventually causing page fault
-                * to be retried by the CPU.
-                */
-               goto unlock;
-       }
-
-       error = radix_tree_insert(page_tree, index,
-                       RADIX_DAX_ENTRY(sector, pmd_entry));
-       if (error)
-               goto unlock;
-
-       mapping->nrexceptional++;
- dirty:
-       if (dirty)
-               radix_tree_tag_set(page_tree, index, PAGECACHE_TAG_DIRTY);
- unlock:
-       spin_unlock_irq(&mapping->tree_lock);
-       return error;
-}
-
 static int dax_writeback_one(struct block_device *bdev,
                struct address_space *mapping, pgoff_t index, void *entry)
 {
        struct radix_tree_root *page_tree = &mapping->page_tree;
-       int type = RADIX_DAX_TYPE(entry);
+       unsigned size = RADIX_DAX_SIZE(entry);
        struct radix_tree_node *node;
-       struct blk_dax_ctl dax;
+       void __pmem *addr;
        void **slot;
        int ret = 0;
 
@@ -412,38 +570,14 @@ static int dax_writeback_one(struct block_device *bdev,
        /* another fsync thread may have already written back this entry */
        if (!radix_tree_tag_get(page_tree, index, PAGECACHE_TAG_TOWRITE))
                goto unlock;
-
-       if (WARN_ON_ONCE(type != RADIX_DAX_PTE && type != RADIX_DAX_PMD)) {
-               ret = -EIO;
-               goto unlock;
-       }
-
-       dax.sector = RADIX_DAX_SECTOR(entry);
-       dax.size = (type == RADIX_DAX_PMD ? PMD_SIZE : PAGE_SIZE);
        spin_unlock_irq(&mapping->tree_lock);
 
-       /*
-        * We cannot hold tree_lock while calling dax_map_atomic() because it
-        * eventually calls cond_resched().
-        */
-       ret = dax_map_atomic(bdev, &dax);
-       if (ret < 0)
-               return ret;
-
-       if (WARN_ON_ONCE(ret < dax.size)) {
-               ret = -EIO;
-               goto unmap;
-       }
-
-       wb_cache_pmem(dax.addr, dax.size);
+       addr = dax_map_pfn(radix_to_pfn_t(entry, index), index);
+       wb_cache_pmem(addr, size_to_bytes(size));
+       dax_unmap_pfn(addr);
 
        spin_lock_irq(&mapping->tree_lock);
        radix_tree_tag_clear(page_tree, index, PAGECACHE_TAG_TOWRITE);
-       spin_unlock_irq(&mapping->tree_lock);
- unmap:
-       dax_unmap_atomic(bdev, &dax);
-       return ret;
-
  unlock:
        spin_unlock_irq(&mapping->tree_lock);
        return ret;
@@ -459,27 +593,17 @@ int dax_writeback_mapping_range(struct address_space *mapping, loff_t start,
 {
        struct inode *inode = mapping->host;
        struct block_device *bdev = inode->i_sb->s_bdev;
-       pgoff_t start_index, end_index, pmd_index;
+       pgoff_t start_index, end_index;
        pgoff_t indices[PAGEVEC_SIZE];
        struct pagevec pvec;
        bool done = false;
        int i, ret = 0;
-       void *entry;
 
        if (WARN_ON_ONCE(inode->i_blkbits != PAGE_SHIFT))
                return -EIO;
 
        start_index = start >> PAGE_CACHE_SHIFT;
        end_index = end >> PAGE_CACHE_SHIFT;
-       pmd_index = DAX_PMD_INDEX(start_index);
-
-       rcu_read_lock();
-       entry = radix_tree_lookup(&mapping->page_tree, pmd_index);
-       rcu_read_unlock();
-
-       /* see if the start of our range is covered by a PMD entry */
-       if (entry && RADIX_DAX_TYPE(entry) == RADIX_DAX_PMD)
-               start_index = pmd_index;
 
        tag_pages_for_writeback(mapping, start_index, end_index);
 
@@ -509,107 +633,80 @@ int dax_writeback_mapping_range(struct address_space 
*mapping, loff_t start,
 }
 EXPORT_SYMBOL_GPL(dax_writeback_mapping_range);
 
-static int dax_insert_mapping(struct inode *inode, struct buffer_head *bh,
-                       struct vm_area_struct *vma, struct vm_fault *vmf)
-{
-       unsigned long vaddr = (unsigned long)vmf->virtual_address;
-       struct address_space *mapping = inode->i_mapping;
-       struct block_device *bdev = bh->b_bdev;
-       struct blk_dax_ctl dax = {
-               .sector = to_sector(bh, inode),
-               .size = bh->b_size,
-       };
-       int error;
-
-       if (dax_map_atomic(bdev, &dax) < 0) {
-               error = PTR_ERR(dax.addr);
-               goto out;
-       }
-
-       if (buffer_unwritten(bh) || buffer_new(bh)) {
-               clear_pmem(dax.addr, PAGE_SIZE);
-               wmb_pmem();
-       }
-       dax_unmap_atomic(bdev, &dax);
-
-       error = dax_radix_entry(mapping, vmf->pgoff, dax.sector, false,
-                       vmf->flags & FAULT_FLAG_WRITE);
-       if (error)
-               goto out;
-
-       error = vm_insert_mixed(vma, vaddr, dax.pfn);
-
- out:
-       return error;
-}
-
 static int dax_pte_fault(struct vm_area_struct *vma, struct vm_fault *vmf,
                        get_block_t get_block, dax_iodone_t complete_unwritten)
 {
-       struct file *file = vma->vm_file;
-       struct address_space *mapping = file->f_mapping;
+       struct address_space *mapping = vma->vm_file->f_mapping;
        struct inode *inode = mapping->host;
        struct page *page;
+       pfn_t pfn = bad_pfn_t;  /* a page cache page means a hole */
        struct buffer_head bh;
        unsigned long vaddr = (unsigned long)vmf->virtual_address;
-       unsigned blkbits = inode->i_blkbits;
-       sector_t block;
        pgoff_t size;
        int error;
        int major = 0;
+       bool write = vmf->flags & FAULT_FLAG_WRITE;
 
        size = (i_size_read(inode) + PAGE_CACHE_SIZE - 1) >> PAGE_CACHE_SHIFT;
        if (vmf->pgoff >= size)
                return VM_FAULT_SIGBUS;
 
        memset(&bh, 0, sizeof(bh));
-       block = (sector_t)vmf->pgoff << (PAGE_CACHE_SHIFT - blkbits);
-       bh.b_bdev = inode->i_sb->s_bdev;
-       bh.b_size = PAGE_CACHE_SIZE;
 
  repeat:
-       page = find_get_page(mapping, vmf->pgoff);
-       if (page) {
+       page = find_get_entry(mapping, vmf->pgoff);
+       if (radix_tree_exceptional_entry(page)) {
+               pfn = radix_to_pfn_t(page, vmf->pgoff);
+               page = NULL;
+       } else if (!page) {
+               error = dax_create_pfns(mapping, vmf->pgoff, PAGE_CACHE_SIZE,
+                               write && !vmf->cow_page, &pfn, get_block, &bh);
+               if (error < 0)
+                       goto out;
+
+               if (error) {
+                       count_vm_event(PGMAJFAULT);
+                       mem_cgroup_count_vm_event(vma->vm_mm, PGMAJFAULT);
+                       major = VM_FAULT_MAJOR;
+                       error = 0;
+               }
+       } else {
                if (!lock_page_or_retry(page, vma->vm_mm, vmf->flags)) {
                        page_cache_release(page);
                        return VM_FAULT_RETRY;
-               }
-               if (unlikely(page->mapping != mapping)) {
+               } else if (unlikely(page->mapping != mapping)) {
                        unlock_page(page);
                        page_cache_release(page);
                        goto repeat;
                }
        }
 
-       error = get_block(inode, block, &bh, 0);
-       if (!error && (bh.b_size < PAGE_CACHE_SIZE))
-               error = -EIO;           /* fs corruption? */
-       if (error)
-               goto unlock_page;
+       if (is_bad_pfn_t(pfn) && !vmf->cow_page) {
+               /*
+                * Allocating a new page in the file would cause excessive
+                * storage usage for workloads with sparse files.  We allocate
+                * a page cache page instead.  We'll kick it out of the page
+                * cache if it's ever written to, otherwise it will simply
+                * fall out of the page cache under memory pressure without
+                * ever having been dirtied.
+                */
+               if (!page)
+                       page = find_or_create_page(mapping, vmf->pgoff,
+                                               vmf->gfp_mask | __GFP_ZERO);
+               if (!page)
+                       return VM_FAULT_OOM;
+               vmf->page = page;
+               return VM_FAULT_LOCKED;
+       }
 
-       if (!buffer_mapped(&bh) && !buffer_unwritten(&bh) && !vmf->cow_page) {
-               if (vmf->flags & FAULT_FLAG_WRITE) {
-                       error = get_block(inode, block, &bh, 1);
-                       count_vm_event(PGMAJFAULT);
-                       mem_cgroup_count_vm_event(vma->vm_mm, PGMAJFAULT);
-                       major = VM_FAULT_MAJOR;
-                       if (!error && (bh.b_size < PAGE_CACHE_SIZE))
-                               error = -EIO;
+       if (vmf->cow_page) {
+               if (is_bad_pfn_t(pfn)) {
+                       clear_user_highpage(vmf->cow_page, vaddr);
+               } else {
+                       error = copy_user_pfn(vmf, pfn);
                        if (error)
                                goto unlock_page;
-               } else {
-                       return dax_load_hole(mapping, page, vmf);
                }
-       }
-
-       if (vmf->cow_page) {
-               struct page *new_page = vmf->cow_page;
-               if (buffer_written(&bh))
-                       error = copy_user_bh(new_page, inode, &bh, vaddr);
-               else
-                       clear_user_highpage(new_page, vaddr);
-               if (error)
-                       goto unlock_page;
                vmf->page = page;
 
                /*
@@ -625,18 +722,10 @@ static int dax_pte_fault(struct vm_area_struct *vma, 
struct vm_fault *vmf,
                return VM_FAULT_LOCKED;
        }
 
-       /* Check we didn't race with a read fault installing a new page */
-       if (!page && major)
-               page = find_lock_page(mapping, vmf->pgoff);
+       if (current->needs_wmb)
+               wmb_pmem();
 
-       if (page) {
-               unmap_mapping_range(mapping, vmf->pgoff << PAGE_CACHE_SHIFT,
-                                                       PAGE_CACHE_SIZE, 0);
-               delete_from_page_cache(page);
-               unlock_page(page);
-               page_cache_release(page);
-               page = NULL;
-       }
+       error = vm_insert_mixed(vma, vaddr, pfn);
 
        /*
         * If we successfully insert the new mapping over an unwritten extent,
@@ -648,12 +737,11 @@ static int dax_pte_fault(struct vm_area_struct *vma, 
struct vm_fault *vmf,
         * indicate what the callback should do via the uptodate variable, same
         * as for normal BH based IO completions.
         */
-       error = dax_insert_mapping(inode, &bh, vma, vmf);
        if (buffer_unwritten(&bh)) {
                if (complete_unwritten)
                        complete_unwritten(&bh, !error);
                else
-                       WARN_ON_ONCE(!(vmf->flags & FAULT_FLAG_WRITE));
+                       WARN_ON_ONCE(!write);
        }
 
  out:
@@ -673,12 +761,6 @@ static int dax_pte_fault(struct vm_area_struct *vma, 
struct vm_fault *vmf,
 }
 
 #ifdef CONFIG_TRANSPARENT_HUGEPAGE
-/*
- * The 'colour' (ie low bits) within a PMD of a page offset.  This comes up
- * more often than one might expect in the below function.
- */
-#define PG_PMD_COLOUR  ((PMD_SIZE >> PAGE_CACHE_SHIFT) - 1)
-
 static void __dax_dbg(struct buffer_head *bh, unsigned long address,
                const char *reason, const char *fn)
 {
@@ -697,98 +779,19 @@ static void __dax_dbg(struct buffer_head *bh, unsigned 
long address,
 
 #define dax_pmd_dbg(bh, address, reason)       __dax_dbg(bh, address, reason, 
"dax_pmd")
 
-static int dax_insert_pmd_mapping(struct inode *inode, struct buffer_head *bh,
-                       struct vm_area_struct *vma, struct vm_fault *vmf)
-{
-       int major = 0;
-       struct blk_dax_ctl dax = {
-               .sector = to_sector(bh, inode),
-               .size = PMD_SIZE,
-       };
-       struct block_device *bdev = bh->b_bdev;
-       bool write = vmf->flags & FAULT_FLAG_WRITE;
-       unsigned long address = (unsigned long)vmf->virtual_address;
-       long length = dax_map_atomic(bdev, &dax);
-
-       if (length < 0)
-               return VM_FAULT_SIGBUS;
-       if (length < PMD_SIZE) {
-               dax_pmd_dbg(bh, address, "dax-length too small");
-               goto unmap;
-       }
-
-       if (pfn_t_to_pfn(dax.pfn) & PG_PMD_COLOUR) {
-               dax_pmd_dbg(bh, address, "pfn unaligned");
-               goto unmap;
-       }
-
-       if (!pfn_t_devmap(dax.pfn)) {
-               dax_pmd_dbg(bh, address, "pfn not in memmap");
-               goto unmap;
-       }
-
-       if (buffer_unwritten(bh) || buffer_new(bh)) {
-               clear_pmem(dax.addr, PMD_SIZE);
-               wmb_pmem();
-               count_vm_event(PGMAJFAULT);
-               mem_cgroup_count_vm_event(vma->vm_mm, PGMAJFAULT);
-               major = VM_FAULT_MAJOR;
-       }
-       dax_unmap_atomic(bdev, &dax);
-
-       /*
-        * For PTE faults we insert a radix tree entry for reads, and leave
-        * it clean.  Then on the first write we dirty the radix tree entry
-        * via the dax_pfn_mkwrite() path.  This sequence allows the
-        * dax_pfn_mkwrite() call to be simpler and avoid a call into
-        * get_block() to translate the pgoff to a sector in order to be able
-        * to create a new radix tree entry.
-        *
-        * The PMD path doesn't have an equivalent to dax_pfn_mkwrite(),
-        * though, so for a read followed by a write we traverse all the way
-        * through dax_pmd_fault() twice.  This means we can just skip
-        * inserting a radix tree entry completely on the initial read and
-        * just wait until the write to insert a dirty entry.
-        */
-       if (write) {
-               int error = dax_radix_entry(vma->vm_file->f_mapping, vmf->pgoff,
-                                               dax.sector, true, true);
-               if (error) {
-                       dax_pmd_dbg(bh, address,
-                                       "PMD radix insertion failed");
-                       goto fallback;
-               }
-       }
-
-       dev_dbg(part_to_dev(bdev->bd_part),
-                       "%s: %s addr: %lx pfn: %lx sect: %llx\n",
-                       __func__, current->comm, address,
-                       pfn_t_to_pfn(dax.pfn),
-                       (unsigned long long) dax.sector);
-       return major | vmf_insert_pfn_pmd(vma, address, vmf->pmd,
-                                               dax.pfn, write);
- unmap:
-       dax_unmap_atomic(bdev, &dax);
- fallback:
-       count_vm_event(THP_FAULT_FALLBACK);
-       return VM_FAULT_FALLBACK;
-}
-
 static int dax_pmd_fault(struct vm_area_struct *vma, struct vm_fault *vmf,
                get_block_t get_block, dax_iodone_t complete_unwritten)
 {
-       struct file *file = vma->vm_file;
-       struct address_space *mapping = file->f_mapping;
+       struct address_space *mapping = vma->vm_file->f_mapping;
        struct inode *inode = mapping->host;
+       void *entry;
+       pfn_t pfn;
        struct buffer_head bh;
-       unsigned blkbits = inode->i_blkbits;
-       unsigned long address = (unsigned long)vmf->virtual_address;
-       unsigned long pmd_addr = address & PMD_MASK;
-       bool write = vmf->flags & FAULT_FLAG_WRITE;
+       unsigned long vaddr = (unsigned long)vmf->virtual_address;
+       unsigned long pmd_addr = vaddr & PMD_MASK;
        pgoff_t size;
-       sector_t block;
-       int result;
-       bool alloc = false;
+       int result = 0;
+       bool write = vmf->flags & FAULT_FLAG_WRITE;
 
        /* dax pmd mappings require pfn_t_devmap() */
        if (!IS_ENABLED(CONFIG_FS_DAX_PMD))
@@ -796,17 +799,17 @@ static int dax_pmd_fault(struct vm_area_struct *vma, 
struct vm_fault *vmf,
 
        /* Fall back to PTEs if we're going to COW */
        if (write && !(vma->vm_flags & VM_SHARED)) {
-               split_huge_pmd(vma, vmf->pmd, address);
-               dax_pmd_dbg(NULL, address, "cow write");
+               split_huge_pmd(vma, vmf->pmd, vaddr);
+               dax_pmd_dbg(NULL, vaddr, "cow write");
                return VM_FAULT_FALLBACK;
        }
        /* If the PMD would extend outside the VMA */
        if (pmd_addr < vma->vm_start) {
-               dax_pmd_dbg(NULL, address, "vma start unaligned");
+               dax_pmd_dbg(NULL, vaddr, "vma start unaligned");
                return VM_FAULT_FALLBACK;
        }
        if ((pmd_addr + PMD_SIZE) > vma->vm_end) {
-               dax_pmd_dbg(NULL, address, "vma end unaligned");
+               dax_pmd_dbg(NULL, vaddr, "vma end unaligned");
                return VM_FAULT_FALLBACK;
        }
 
@@ -815,76 +818,69 @@ static int dax_pmd_fault(struct vm_area_struct *vma, 
struct vm_fault *vmf,
                return VM_FAULT_SIGBUS;
        /* If the PMD would cover blocks out of the file */
        if ((vmf->pgoff | PG_PMD_COLOUR) >= size) {
-               dax_pmd_dbg(NULL, address,
+               dax_pmd_dbg(NULL, vaddr,
                                "offset + huge page size > file size");
                return VM_FAULT_FALLBACK;
        }
 
        memset(&bh, 0, sizeof(bh));
-       bh.b_bdev = inode->i_sb->s_bdev;
-       block = (sector_t)vmf->pgoff << (PAGE_CACHE_SHIFT - blkbits);
-
-       bh.b_size = PMD_SIZE;
-
-       if (get_block(inode, block, &bh, 0) != 0)
-               return VM_FAULT_SIGBUS;
-
-       if (!buffer_mapped(&bh) && write) {
-               if (get_block(inode, block, &bh, 1) != 0)
-                       return VM_FAULT_SIGBUS;
-               alloc = true;
-       }
 
-       /*
-        * If the filesystem isn't willing to tell us the length of a hole,
-        * just fall back to PTEs.  Calling get_block 512 times in a loop
-        * would be silly.
-        */
-       if (!buffer_size_valid(&bh) || bh.b_size < PMD_SIZE) {
-               dax_pmd_dbg(&bh, address, "allocated block too small");
-               return VM_FAULT_FALLBACK;
-       }
+       entry = find_get_entry(mapping, vmf->pgoff);
+       if (radix_tree_exceptional_entry(entry) &&
+                               RADIX_DAX_SIZE(entry) >= RADIX_DAX_PMD) {
+               pfn = radix_to_pfn_t(entry, vmf->pgoff);
+       } else {
+               int error = dax_create_pfns(mapping, vmf->pgoff, PMD_SIZE,
+                                       write, &pfn, get_block, &bh);
+               if (error < 0)
+                       goto fallback;
 
-       /*
-        * If we allocated new storage, make sure no process has any
-        * zero pages covering this hole
-        */
-       if (alloc) {
-               loff_t lstart = vmf->pgoff << PAGE_CACHE_SHIFT;
-               loff_t lend = lstart + PMD_SIZE - 1; /* inclusive */
+               if (error) {
+                       count_vm_event(PGMAJFAULT);
+                       mem_cgroup_count_vm_event(vma->vm_mm, PGMAJFAULT);
+                       result = VM_FAULT_MAJOR;
+                       error = 0;
+               }
 
-               truncate_pagecache_range(inode, lstart, lend);
+               /*
+                * We don't know if dax_create_pfns() was able to allocate
+                * a contiguous aligned chunk, or whether it was only able
+                * to do a partial allocation.
+                */
+               entry = find_get_entry(mapping, vmf->pgoff);
+               if (!radix_tree_exceptional_entry(entry) ||
+                               RADIX_DAX_SIZE(entry) < RADIX_DAX_PMD)
+                       goto fallback;
        }
 
-       if (!write && !buffer_mapped(&bh) && buffer_uptodate(&bh)) {
+       if (is_bad_pfn_t(pfn)) {
                spinlock_t *ptl;
                pmd_t entry, *pmd = vmf->pmd;
                struct page *zero_page = get_huge_zero_page();
 
                if (unlikely(!zero_page)) {
-                       dax_pmd_dbg(&bh, address, "no zero page");
+                       dax_pmd_dbg(&bh, vaddr, "no zero page");
                        goto fallback;
                }
 
                ptl = pmd_lock(vma->vm_mm, pmd);
                if (!pmd_none(*pmd)) {
                        spin_unlock(ptl);
-                       dax_pmd_dbg(&bh, address, "pmd already present");
+                       dax_pmd_dbg(&bh, vaddr, "pmd already present");
                        goto fallback;
                }
 
-               dev_dbg(part_to_dev(bh.b_bdev->bd_part),
-                               "%s: %s addr: %lx pfn: <zero> sect: %llx\n",
-                               __func__, current->comm, address,
-                               (unsigned long long) to_sector(&bh, inode));
-
                entry = mk_pmd(zero_page, vma->vm_page_prot);
                entry = pmd_mkhuge(entry);
                set_pmd_at(vma->vm_mm, pmd_addr, pmd, entry);
                result = VM_FAULT_NOPAGE;
                spin_unlock(ptl);
        } else {
-               result = dax_insert_pmd_mapping(inode, &bh, vma, vmf);
+               if (current->needs_wmb)
+                       wmb_pmem();
+
+               result |= vmf_insert_pfn_pmd(vma, vaddr, vmf->pmd, pfn,
+                                               write);
        }
 
  out:
@@ -907,80 +903,21 @@ static int dax_pmd_fault(struct vm_area_struct *vma, 
struct vm_fault *vmf,
 #endif /* !CONFIG_TRANSPARENT_HUGEPAGE */
 
 #ifdef CONFIG_HAVE_ARCH_TRANSPARENT_HUGEPAGE_PUD
-/*
- * The 'colour' (ie low bits) within a PUD of a page offset.  This comes up
- * more often than one might expect in the below function.
- */
-#define PG_PUD_COLOUR  ((PUD_SIZE >> PAGE_CACHE_SHIFT) - 1)
-
 #define dax_pud_dbg(bh, address, reason)       __dax_dbg(bh, address, reason, 
"dax_pud")
 
-static int dax_insert_pud_mapping(struct inode *inode, struct buffer_head *bh,
-                       struct vm_area_struct *vma, struct vm_fault *vmf)
-{
-       int major = 0;
-       struct blk_dax_ctl dax = {
-               .sector = to_sector(bh, inode),
-               .size = PUD_SIZE,
-       };
-       struct block_device *bdev = bh->b_bdev;
-       bool write = vmf->flags & FAULT_FLAG_WRITE;
-       unsigned long address = (unsigned long)vmf->virtual_address;
-       long length = dax_map_atomic(bdev, &dax);
-
-       if (length < 0)
-               return VM_FAULT_SIGBUS;
-       if (length < PUD_SIZE) {
-               dax_pud_dbg(bh, address, "dax-length too small");
-               goto unmap;
-       }
-       if (pfn_t_to_pfn(dax.pfn) & PG_PUD_COLOUR) {
-               dax_pud_dbg(bh, address, "pfn unaligned");
-               goto unmap;
-       }
-
-       if (!pfn_t_devmap(dax.pfn)) {
-               dax_pud_dbg(bh, address, "pfn not in memmap");
-               goto unmap;
-       }
-
-       if (buffer_unwritten(bh) || buffer_new(bh)) {
-               clear_pmem(dax.addr, PUD_SIZE);
-               wmb_pmem();
-               count_vm_event(PGMAJFAULT);
-               mem_cgroup_count_vm_event(vma->vm_mm, PGMAJFAULT);
-               major = VM_FAULT_MAJOR;
-       }
-       dax_unmap_atomic(bdev, &dax);
-
-       dev_dbg(part_to_dev(bdev->bd_part),
-                       "%s: %s addr: %lx pfn: %lx sect: %llx\n",
-                       __func__, current->comm, address,
-                       pfn_t_to_pfn(dax.pfn),
-                       (unsigned long long) dax.sector);
-       return major | vmf_insert_pfn_pud(vma, address, vmf->pud,
-                                               dax.pfn, write);
- unmap:
-       dax_unmap_atomic(bdev, &dax);
-       count_vm_event(THP_FAULT_FALLBACK);
-       return VM_FAULT_FALLBACK;
-}
-
 static int dax_pud_fault(struct vm_area_struct *vma, struct vm_fault *vmf,
                get_block_t get_block, dax_iodone_t complete_unwritten)
 {
-       struct file *file = vma->vm_file;
-       struct address_space *mapping = file->f_mapping;
+       struct address_space *mapping = vma->vm_file->f_mapping;
        struct inode *inode = mapping->host;
+       void *entry;
+       pfn_t pfn;
        struct buffer_head bh;
-       unsigned blkbits = inode->i_blkbits;
-       unsigned long address = (unsigned long)vmf->virtual_address;
-       unsigned long pud_addr = address & PUD_MASK;
-       bool write = vmf->flags & FAULT_FLAG_WRITE;
+       unsigned long vaddr = (unsigned long)vmf->virtual_address;
+       unsigned long pud_addr = vaddr & PUD_MASK;
        pgoff_t size;
-       sector_t block;
-       int result;
-       bool alloc = false;
+       int result = 0;
+       bool write = vmf->flags & FAULT_FLAG_WRITE;
 
        /* dax pud mappings require pfn_t_devmap() */
        if (!IS_ENABLED(CONFIG_FS_DAX_PMD))
@@ -988,17 +925,17 @@ static int dax_pud_fault(struct vm_area_struct *vma, 
struct vm_fault *vmf,
 
        /* Fall back to PTEs if we're going to COW */
        if (write && !(vma->vm_flags & VM_SHARED)) {
-               split_huge_pud(vma, vmf->pud, address);
-               dax_pud_dbg(NULL, address, "cow write");
+               split_huge_pud(vma, vmf->pud, vaddr);
+               dax_pud_dbg(NULL, vaddr, "cow write");
                return VM_FAULT_FALLBACK;
        }
        /* If the PUD would extend outside the VMA */
        if (pud_addr < vma->vm_start) {
-               dax_pud_dbg(NULL, address, "vma start unaligned");
+               dax_pud_dbg(NULL, vaddr, "vma start unaligned");
                return VM_FAULT_FALLBACK;
        }
        if ((pud_addr + PUD_SIZE) > vma->vm_end) {
-               dax_pud_dbg(NULL, address, "vma end unaligned");
+               dax_pud_dbg(NULL, vaddr, "vma end unaligned");
                return VM_FAULT_FALLBACK;
        }
 
@@ -1007,52 +944,50 @@ static int dax_pud_fault(struct vm_area_struct *vma, 
struct vm_fault *vmf,
                return VM_FAULT_SIGBUS;
        /* If the PUD would cover blocks out of the file */
        if ((vmf->pgoff | PG_PUD_COLOUR) >= size) {
-               dax_pud_dbg(NULL, address,
+               dax_pud_dbg(NULL, vaddr,
                                "offset + huge page size > file size");
                return VM_FAULT_FALLBACK;
        }
 
        memset(&bh, 0, sizeof(bh));
-       bh.b_bdev = inode->i_sb->s_bdev;
-       block = (sector_t)vmf->pgoff << (PAGE_CACHE_SHIFT - blkbits);
 
-       bh.b_size = PUD_SIZE;
-
-       if (get_block(inode, block, &bh, 0) != 0)
-               return VM_FAULT_SIGBUS;
-
-       if (!buffer_mapped(&bh) && write) {
-               if (get_block(inode, block, &bh, 1) != 0)
-                       return VM_FAULT_SIGBUS;
-               alloc = true;
-       }
-
-       /*
-        * If the filesystem isn't willing to tell us the length of a hole,
-        * just fall back to PMDs.  Calling get_block 512 times in a loop
-        * would be silly.
-        */
-       if (!buffer_size_valid(&bh) || bh.b_size < PUD_SIZE) {
-               dax_pud_dbg(&bh, address, "allocated block too small");
-               return VM_FAULT_FALLBACK;
-       }
+       entry = find_get_entry(mapping, vmf->pgoff);
+       if (radix_tree_exceptional_entry(entry) &&
+                               RADIX_DAX_SIZE(entry) >= RADIX_DAX_PUD) {
+               pfn = radix_to_pfn_t(entry, vmf->pgoff);
+       } else {
+               int error = dax_create_pfns(mapping, vmf->pgoff, PUD_SIZE,
+                                       write, &pfn, get_block, &bh);
+               if (error < 0)
+                       goto fallback;
 
-       /*
-        * If we allocated new storage, make sure no process has any
-        * zero pages covering this hole
-        */
-       if (alloc) {
-               loff_t lstart = vmf->pgoff << PAGE_CACHE_SHIFT;
-               loff_t lend = lstart + PUD_SIZE - 1; /* inclusive */
+               if (error) {
+                       count_vm_event(PGMAJFAULT);
+                       mem_cgroup_count_vm_event(vma->vm_mm, PGMAJFAULT);
+                       result = VM_FAULT_MAJOR;
+                       error = 0;
+               }
 
-               truncate_pagecache_range(inode, lstart, lend);
+               /*
+                * We don't know if dax_create_pfns() was able to allocate
+                * a contiguous aligned chunk, or whether it was only able
+                * to do a partial allocation.
+                */
+               entry = find_get_entry(mapping, vmf->pgoff);
+               if (!radix_tree_exceptional_entry(entry) ||
+                               RADIX_DAX_SIZE(entry) < RADIX_DAX_PUD)
+                       goto fallback;
        }
 
-       if (!write && !buffer_mapped(&bh) && buffer_uptodate(&bh)) {
-               dax_pud_dbg(&bh, address, "no zero page");
+       if (is_bad_pfn_t(pfn)) {
+               dax_pud_dbg(&bh, vaddr, "no zero page");
                goto fallback;
        } else {
-               result = dax_insert_pud_mapping(inode, &bh, vma, vmf);
+               if (current->needs_wmb)
+                       wmb_pmem();
+
+               result |= vmf_insert_pfn_pud(vma, vaddr, vmf->pud, pfn,
+                                               write);
        }
 
  out:
@@ -1113,17 +1048,19 @@ EXPORT_SYMBOL_GPL(dax_fault);
  */
 int dax_pfn_mkwrite(struct vm_area_struct *vma, struct vm_fault *vmf)
 {
-       struct file *file = vma->vm_file;
+       struct address_space *mapping = vma->vm_file->f_mapping;
+
+       spin_lock_irq(&mapping->tree_lock);
+       /*
+        * The entry inserted by dax_fault() may already have been removed
+        * (e.g. by a racing hole punch or truncate).  radix_tree_tag_set()
+        * must only be called on a populated index, so check first.
+        */
+       if (radix_tree_lookup(&mapping->page_tree, vmf->pgoff))
+               radix_tree_tag_set(&mapping->page_tree, vmf->pgoff,
+                                               PAGECACHE_TAG_DIRTY);
+       spin_unlock_irq(&mapping->tree_lock);
 
-       /*
-        * We pass NO_SECTOR to dax_radix_entry() because we expect that a
-        * RADIX_DAX_PTE entry already exists in the radix tree from a
-        * previous call to dax_fault().  We just want to look up that PTE
-        * entry using vmf->pgoff and make sure the dirty tag is set.  This
-        * saves us from having to make a call to get_block() here to look
-        * up the sector.
-        */
-       dax_radix_entry(file->f_mapping, vmf->pgoff, NO_SECTOR, false, true);
        return VM_FAULT_NOPAGE;
 }
 EXPORT_SYMBOL_GPL(dax_pfn_mkwrite);
diff --git a/include/linux/dax.h b/include/linux/dax.h
index 8e58c36..0a6505d 100644
--- a/include/linux/dax.h
+++ b/include/linux/dax.h
@@ -5,9 +5,10 @@
 #include <linux/mm.h>
 #include <asm/pgtable.h>
 
+int dax_clear_blocks(struct block_device *, sector_t sector, long size);
+
 ssize_t dax_do_io(struct kiocb *, struct inode *, struct iov_iter *, loff_t,
                  get_block_t, dio_iodone_t, int flags);
-int dax_clear_blocks(struct inode *, sector_t block, long size);
 int dax_zero_page_range(struct inode *, loff_t from, unsigned len, 
get_block_t);
 int dax_truncate_page(struct inode *, loff_t from, get_block_t);
 int dax_fault(struct vm_area_struct *, struct vm_fault *, get_block_t,
diff --git a/include/linux/pfn_t.h b/include/linux/pfn_t.h
index 07d18a8..95a7b50 100644
--- a/include/linux/pfn_t.h
+++ b/include/linux/pfn_t.h
@@ -8,21 +8,53 @@
  * PFN_SG_LAST - pfn references a page and is the last scatterlist entry
  * PFN_DEV - pfn is not covered by system memmap by default
  * PFN_MAP - pfn has a dynamic page mapping established by a device driver
+ *
+ * Note that DAX uses the same format for its radix tree entries.  The
+ * bottom two bits (PFN_SG_CHAIN and PFN_SG_LAST) are used by the radix
+ * tree itself, so a pfn_t stored there must not have either flag set.
  */
-#define PFN_FLAGS_MASK (((unsigned long) ~PAGE_MASK) \
-               << (BITS_PER_LONG - PAGE_SHIFT))
-#define PFN_SG_CHAIN (1UL << (BITS_PER_LONG - 1))
-#define PFN_SG_LAST (1UL << (BITS_PER_LONG - 2))
-#define PFN_DEV (1UL << (BITS_PER_LONG - 3))
-#define PFN_MAP (1UL << (BITS_PER_LONG - 4))
+#define PFN_FLAG_BITS  4
+#define PFN_FLAGS_MASK ((1UL << PFN_FLAG_BITS) - 1)
+#define PFN_SG_CHAIN   0x1UL
+#define PFN_SG_LAST    0x2UL
+#define PFN_DEV                0x4UL
+#define PFN_MAP                0x8UL
 
 static inline pfn_t __pfn_to_pfn_t(unsigned long pfn, unsigned long flags)
 {
-       pfn_t pfn_t = { .val = pfn | (flags & PFN_FLAGS_MASK), };
+       pfn_t pfn_t = { .val = (pfn << PFN_FLAG_BITS) |
+                                       (flags & PFN_FLAGS_MASK), };
 
        return pfn_t;
 }
 
+/*
+ * Offset @pfn by @val page frames, keeping its flag bits unchanged.
+ * The shift is done on an unsigned long: left-shifting a negative int,
+ * or an int whose shifted value exceeds INT_MAX, is undefined behaviour.
+ */
+static inline __must_check pfn_t pfn_t_add(const pfn_t pfn, int val)
+{
+       pfn_t tmp = pfn;
+
+       tmp.val += (unsigned long)val << PFN_FLAG_BITS;
+       return tmp;
+}
+
+/*
+ * It makes no sense to have both SG_CHAIN and SG_LAST set, so we could
+ * encode an errno in here if we need to.  Note that you can't put a
+ * bad_pfn_t in the radix tree because the radix tree uses the bottom bit
+ * for its own purposes.
+ */
+#define bad_pfn_t      ((pfn_t) { .val = -1 })
+
+static inline bool is_bad_pfn_t(pfn_t pfn)
+{
+       return ((pfn.val & (PFN_SG_CHAIN | PFN_SG_LAST)) ==
+                         (PFN_SG_CHAIN | PFN_SG_LAST));
+}
+
 /* a default pfn to pfn_t conversion assumes that @pfn is pfn_valid() */
 static inline pfn_t pfn_to_pfn_t(unsigned long pfn)
 {
@@ -38,7 +70,7 @@ static inline bool pfn_t_has_page(pfn_t pfn)
 
 static inline unsigned long pfn_t_to_pfn(pfn_t pfn)
 {
-       return pfn.val & ~PFN_FLAGS_MASK;
+       return pfn.val >> PFN_FLAG_BITS;
 }
 
 static inline struct page *pfn_t_to_page(pfn_t pfn)
diff --git a/include/linux/radix-tree.h b/include/linux/radix-tree.h
index 7c88ad1..57e7d87 100644
--- a/include/linux/radix-tree.h
+++ b/include/linux/radix-tree.h
@@ -51,15 +51,6 @@
 #define RADIX_TREE_EXCEPTIONAL_ENTRY   2
 #define RADIX_TREE_EXCEPTIONAL_SHIFT   2
 
-#define RADIX_DAX_MASK 0xf
-#define RADIX_DAX_SHIFT        4
-#define RADIX_DAX_PTE  (0x4 | RADIX_TREE_EXCEPTIONAL_ENTRY)
-#define RADIX_DAX_PMD  (0x8 | RADIX_TREE_EXCEPTIONAL_ENTRY)
-#define RADIX_DAX_TYPE(entry) ((unsigned long)entry & RADIX_DAX_MASK)
-#define RADIX_DAX_SECTOR(entry) (((unsigned long)entry >> RADIX_DAX_SHIFT))
-#define RADIX_DAX_ENTRY(sector, pmd) ((void *)((unsigned long)sector << \
-               RADIX_DAX_SHIFT | (pmd ? RADIX_DAX_PMD : RADIX_DAX_PTE)))
-
 static inline int radix_tree_is_indirect_ptr(void *ptr)
 {
        return (int)((unsigned long)ptr & RADIX_TREE_INDIRECT_PTR);
diff --git a/include/linux/sched.h b/include/linux/sched.h
index 6e95d8a..2cdfe76 100644
--- a/include/linux/sched.h
+++ b/include/linux/sched.h
@@ -1476,6 +1476,7 @@ struct task_struct {
        /* unserialized, strictly 'current' */
        unsigned in_execve:1; /* bit to tell LSMs we're in execve */
        unsigned in_iowait:1;
+       unsigned needs_wmb:1;
 #ifdef CONFIG_MEMCG
        unsigned memcg_may_oom:1;
 #ifndef CONFIG_SLOB
-- 
2.7.0.rc3

Reply via email to