From: John Groves <[email protected]>

In fsdev_dax_probe(), the dynamic path stored the newly allocated pgmap
in dev_dax->pgmap right away. If probe then failed at any later step,
devres freed the devm_kzalloc()'d pgmap, but dev_dax->pgmap still
pointed at it. Every subsequent probe attempt hit the "dynamic-dax with
pre-populated page map" check and failed, so the device failed probe
permanently.

Factor pgmap acquisition out into fsdev_acquire_pgmap(), and defer the
dev_dax->pgmap assignment until probe can no longer fail. A failed probe
now never publishes the pointer, so there is nothing to unwind. This is
also consistent with kill_dev_dax(), which already clears the dynamic
pgmap pointer on unbind: dev_dax->pgmap is now non-NULL only while the
pgmap is actually valid.

Refactor suggested by Dave Jiang.

Fixes: d5406bd458b0a ("dax: add fsdev.c driver for fs-dax on character dax")
Signed-off-by: John Groves <[email protected]>
---
 drivers/dax/fsdev.c | 77 ++++++++++++++++++++++++++++-----------------
 1 file changed, 49 insertions(+), 28 deletions(-)

diff --git a/drivers/dax/fsdev.c b/drivers/dax/fsdev.c
index dbd722ed7ab05..0fd5e1293d725 100644
--- a/drivers/dax/fsdev.c
+++ b/drivers/dax/fsdev.c
@@ -219,47 +219,62 @@ static const struct file_operations fsdev_fops = {
        .release = fsdev_release,
 };
 
-static int fsdev_dax_probe(struct dev_dax *dev_dax)
+/*
+ * Acquire the dev_pagemap for probe: the static (pre-populated) one if
+ * present, or a devm-allocated one for the dynamic case. Note that
+ * dev_dax->pgmap is not set here; fsdev_dax_probe() sets it only once
+ * probe succeeds, so a failed probe never leaves a dangling pointer
+ * to a devres-freed pgmap.
+ */
+static struct dev_pagemap *fsdev_acquire_pgmap(struct dev_dax *dev_dax)
 {
-       struct dax_device *dax_dev = dev_dax->dax_dev;
        struct device *dev = &dev_dax->dev;
        struct dev_pagemap *pgmap;
-       struct inode *inode;
-       u64 data_offset = 0;
-       struct cdev *cdev;
-       void *addr;
-       int rc, i;
+       size_t pgmap_size;
 
        if (static_dev_dax(dev_dax)) {
                if (dev_dax->nr_range > 1) {
-                       dev_warn(dev, "static pgmap / multi-range device 
conflict\n");
-                       return -EINVAL;
+                       dev_warn(dev,
+                                "static pgmap / multi-range device 
conflict\n");
+                       return ERR_PTR(-EINVAL);
                }
 
                pgmap = dev_dax->pgmap;
                pgmap->vmemmap_shift = 0;
-       } else {
-               size_t pgmap_size;
+               return pgmap;
+       }
 
-               if (dev_dax->pgmap) {
-                       dev_warn(dev, "dynamic-dax with pre-populated page 
map\n");
-                       return -EINVAL;
-               }
+       if (dev_dax->pgmap) {
+               dev_warn(dev, "dynamic-dax with pre-populated page map\n");
+               return ERR_PTR(-EINVAL);
+       }
 
-               pgmap_size = struct_size(pgmap, ranges, dev_dax->nr_range - 1);
-               pgmap = devm_kzalloc(dev, pgmap_size, GFP_KERNEL);
-               if (!pgmap)
-                       return -ENOMEM;
+       pgmap_size = struct_size(pgmap, ranges, dev_dax->nr_range - 1);
+       pgmap = devm_kzalloc(dev, pgmap_size, GFP_KERNEL);
+       if (!pgmap)
+               return ERR_PTR(-ENOMEM);
 
-               pgmap->nr_range = dev_dax->nr_range;
-               dev_dax->pgmap = pgmap;
+       pgmap->nr_range = dev_dax->nr_range;
+       for (int i = 0; i < dev_dax->nr_range; i++)
+               pgmap->ranges[i] = dev_dax->ranges[i].range;
 
-               for (i = 0; i < dev_dax->nr_range; i++) {
-                       struct range *range = &dev_dax->ranges[i].range;
+       return pgmap;
+}
 
-                       pgmap->ranges[i] = *range;
-               }
-       }
+static int fsdev_dax_probe(struct dev_dax *dev_dax)
+{
+       struct dax_device *dax_dev = dev_dax->dax_dev;
+       struct device *dev = &dev_dax->dev;
+       struct dev_pagemap *pgmap;
+       struct inode *inode;
+       u64 data_offset = 0;
+       struct cdev *cdev;
+       void *addr;
+       int rc, i;
+
+       pgmap = fsdev_acquire_pgmap(dev_dax);
+       if (IS_ERR(pgmap))
+               return PTR_ERR(pgmap);
 
        for (i = 0; i < dev_dax->nr_range; i++) {
                struct range *range = &dev_dax->ranges[i].range;
@@ -306,7 +321,7 @@ static int fsdev_dax_probe(struct dev_dax *dev_dax)
        /* Detect whether the data is at a non-zero offset into the memory */
        if (pgmap->range.start != dev_dax->ranges[0].range.start) {
                u64 phys = dev_dax->ranges[0].range.start;
-               u64 pgmap_phys = dev_dax->pgmap[0].range.start;
+               u64 pgmap_phys = pgmap[0].range.start;
 
                if (!WARN_ON(pgmap_phys > phys))
                        data_offset = phys - pgmap_phys;
@@ -339,7 +354,13 @@ static int fsdev_dax_probe(struct dev_dax *dev_dax)
                return rc;
 
        run_dax(dax_dev);
-       return devm_add_action_or_reset(dev, fsdev_kill, dev_dax);
+       rc = devm_add_action_or_reset(dev, fsdev_kill, dev_dax);
+       if (rc)
+               return rc;
+
+       /* Probe can no longer fail; expose the pgmap via dev_dax */
+       dev_dax->pgmap = pgmap;
+       return 0;
 }
 
 static struct dax_device_driver fsdev_dax_driver = {
-- 
2.53.0


Reply via email to