KASAN reports a slab use-after-free in virtio_pmem_host_ack(): the
completion handler wakes up a request that the submitter has already freed.

This can happen because virtio_pmem_flush() may return and free the request
while its token is still reachable through the virtqueue, so a later
completion still touches the freed object.

Fix the token lifetime by refcounting struct virtio_pmem_request.
virtio_pmem_flush() owns the initial (submitter) reference from kref_init()
and takes an extra reference on behalf of the virtqueue once
virtqueue_add_sgs() succeeds. The completion path, virtio_pmem_host_ack(),
drops the virtqueue reference, and the submitter drops its own reference
before returning; whichever put comes last frees the request.

Signed-off-by: Li Chen <[email protected]>
---
 drivers/nvdimm/nd_virtio.c   | 34 +++++++++++++++++++++++++++++-----
 drivers/nvdimm/virtio_pmem.h |  2 ++
 2 files changed, 31 insertions(+), 5 deletions(-)

diff --git a/drivers/nvdimm/nd_virtio.c b/drivers/nvdimm/nd_virtio.c
index 6f9890361d0b..d0385d4646f2 100644
--- a/drivers/nvdimm/nd_virtio.c
+++ b/drivers/nvdimm/nd_virtio.c
@@ -9,6 +9,14 @@
 #include "virtio_pmem.h"
 #include "nd.h"
 
+/*
+ * kref release callback: frees the request when its last reference drops.
+ */
+static void virtio_pmem_req_release(struct kref *kref)
+{
+       kfree(container_of(kref, struct virtio_pmem_request, kref));
+}
+
 static void virtio_pmem_wake_one_waiter(struct virtio_pmem *vpmem)
 {
        struct virtio_pmem_request *req_buf;
@@ -36,6 +44,7 @@ void virtio_pmem_host_ack(struct virtqueue *vq)
                virtio_pmem_wake_one_waiter(vpmem);
                WRITE_ONCE(req_data->done, true);
                wake_up(&req_data->host_acked);
+               kref_put(&req_data->kref, virtio_pmem_req_release);
        }
        spin_unlock_irqrestore(&vpmem->pmem_lock, flags);
 }
@@ -65,6 +74,7 @@ static int virtio_pmem_flush(struct nd_region *nd_region)
        if (!req_data)
                return -ENOMEM;
 
+       kref_init(&req_data->kref);
        WRITE_ONCE(req_data->done, false);
        init_waitqueue_head(&req_data->host_acked);
        init_waitqueue_head(&req_data->wq_buf);
@@ -82,10 +92,23 @@ static int virtio_pmem_flush(struct nd_region *nd_region)
          * to req_list and wait for host_ack to wake us up when free
          * slots are available.
          */
-       while ((err = virtqueue_add_sgs(vpmem->req_vq, sgs, 1, 1, req_data,
-                                       GFP_ATOMIC)) == -ENOSPC) {
-
-               dev_info(&vdev->dev, "failed to send command to virtio pmem device, no free slots in the virtqueue\n");
+       for (;;) {
+               err = virtqueue_add_sgs(vpmem->req_vq, sgs, 1, 1, req_data,
+                                       GFP_ATOMIC);
+               if (!err) {
+                       /*
+                        * Take the virtqueue reference while @pmem_lock is
+                        * held so completion cannot run concurrently.
+                        */
+                       kref_get(&req_data->kref);
+                       break;
+               }
+
+               if (err != -ENOSPC)
+                       break;
+
+               dev_info_ratelimited(&vdev->dev,
+                                    "failed to send command to virtio pmem device, no free slots in the virtqueue\n");
                WRITE_ONCE(req_data->wq_buf_avail, false);
                list_add_tail(&req_data->list, &vpmem->req_list);
                spin_unlock_irqrestore(&vpmem->pmem_lock, flags);
@@ -94,6 +117,7 @@ static int virtio_pmem_flush(struct nd_region *nd_region)
                wait_event(req_data->wq_buf, READ_ONCE(req_data->wq_buf_avail));
                spin_lock_irqsave(&vpmem->pmem_lock, flags);
        }
+
        err1 = virtqueue_kick(vpmem->req_vq);
        spin_unlock_irqrestore(&vpmem->pmem_lock, flags);
        /*
@@ -109,7 +133,7 @@ static int virtio_pmem_flush(struct nd_region *nd_region)
                err = le32_to_cpu(req_data->resp.ret);
        }
 
-       kfree(req_data);
+       kref_put(&req_data->kref, virtio_pmem_req_release);
        return err;
 };
 
diff --git a/drivers/nvdimm/virtio_pmem.h b/drivers/nvdimm/virtio_pmem.h
index 0dddefe594c4..fc8f613f8f28 100644
--- a/drivers/nvdimm/virtio_pmem.h
+++ b/drivers/nvdimm/virtio_pmem.h
@@ -12,10 +12,12 @@
 
 #include <linux/module.h>
 #include <uapi/linux/virtio_pmem.h>
+#include <linux/kref.h>
 #include <linux/libnvdimm.h>
 #include <linux/spinlock.h>
 
 struct virtio_pmem_request {
+       struct kref kref;
        struct virtio_pmem_req req;
        struct virtio_pmem_resp resp;
 
-- 
2.51.0


Reply via email to