If a vop virtio device's interrupt is shared with other devices, there is a
window in which the interrupt can fire and vop_virtio_intr_handler() can
try to traverse the ->vqs list before register_virtio_device() has
initialized it, leading to a NULL pointer dereference.

Fix this by keeping a local list of virtqueues in this driver and using
that instead of the list inside the struct virtio_device, similar to how
virtio-pci handles this.

Signed-off-by: Vincent Whitchurch <vincent.whitchurch@axis.com>
---
 drivers/misc/mic/vop/vop_main.c | 42 ++++++++++++++++++++++++++++++---
 1 file changed, 39 insertions(+), 3 deletions(-)

diff --git a/drivers/misc/mic/vop/vop_main.c b/drivers/misc/mic/vop/vop_main.c
index e37b2c2152a2..6764feea6f55 100644
--- a/drivers/misc/mic/vop/vop_main.c
+++ b/drivers/misc/mic/vop/vop_main.c
@@ -40,6 +40,11 @@
 
 #define VOP_MAX_VRINGS 4
 
+struct vop_vq {
+       struct virtqueue *vq;   /* virtqueue created in vop_find_vqs() */
+       struct list_head list;  /* entry in _vop_vdev::virtqueues, protected by virtqueues_lock */
+};
+
 /*
  * _vop_vdev - Allocated per virtio device instance injected by the peer.
  *
@@ -68,6 +73,9 @@ struct _vop_vdev {
        int used_size[VOP_MAX_VRINGS];
        struct completion reset_done;
        struct mic_irq *virtio_cookie;
+       struct vop_vq *vqs;
+       struct list_head virtqueues;
+       spinlock_t virtqueues_lock;
        int c2h_vdev_db;
        int h2c_vdev_db;
        int dnode;
@@ -264,6 +272,12 @@ static void vop_del_vq(struct virtqueue *vq, int n)
 {
        struct _vop_vdev *vdev = to_vopvdev(vq->vdev);
        struct vop_device *vpdev = vdev->vpdev;
+       struct vop_vq *vopvq = &vdev->vqs[n];
+       unsigned long flags;
+
+       spin_lock_irqsave(&vdev->virtqueues_lock, flags);
+       list_del(&vopvq->list);
+       spin_unlock_irqrestore(&vdev->virtqueues_lock, flags);
 
        dma_unmap_single(&vpdev->dev, vdev->used[n],
                         vdev->used_size[n], DMA_BIDIRECTIONAL);
@@ -284,6 +298,9 @@ static void vop_del_vqs(struct virtio_device *dev)
 
        list_for_each_entry_safe(vq, n, &dev->vqs, list)
                vop_del_vq(vq, idx++);
+
+       kfree(vdev->vqs);
+       vdev->vqs = NULL;
 }
 
 static struct virtqueue *vop_new_virtqueue(unsigned int index,
@@ -411,7 +428,13 @@ static int vop_find_vqs(struct virtio_device *dev, unsigned nvqs,
        if (nvqs > ioread8(&vdev->desc->num_vq))
                return -ENOENT;
 
+       vdev->vqs = kcalloc(nvqs, sizeof(*vdev->vqs), GFP_KERNEL);
+       if (!vdev->vqs)
+               return -ENOMEM;
+
        for (i = 0; i < nvqs; ++i) {
+               unsigned long flags;
+
                if (!names[i]) {
                        vqs[i] = NULL;
                        continue;
@@ -425,6 +448,12 @@ static int vop_find_vqs(struct virtio_device *dev, unsigned nvqs,
                        err = PTR_ERR(vqs[i]);
                        goto error;
                }
+
+               vdev->vqs[i].vq = vqs[i];
+
+               spin_lock_irqsave(&vdev->virtqueues_lock, flags);
+               list_add(&vdev->vqs[i].list, &vdev->virtqueues);
+               spin_unlock_irqrestore(&vdev->virtqueues_lock, flags);
        }
 
        iowrite8(1, &dc->used_address_updated);
@@ -468,13 +497,17 @@ static struct virtio_config_ops vop_vq_config_ops = {
 
 static irqreturn_t vop_virtio_intr_handler(int irq, void *data)
 {
+       unsigned long flags;
        struct _vop_vdev *vdev = data;
        struct vop_device *vpdev = vdev->vpdev;
-       struct virtqueue *vq;
+       struct vop_vq *vopvq;
 
        vpdev->hw_ops->ack_interrupt(vpdev, vdev->h2c_vdev_db);
-       list_for_each_entry(vq, &vdev->vdev.vqs, list)
-               vring_interrupt(0, vq);
+
+       spin_lock_irqsave(&vdev->virtqueues_lock, flags);
+       list_for_each_entry(vopvq, &vdev->virtqueues, list)
+               vring_interrupt(0, vopvq->vq);
+       spin_unlock_irqrestore(&vdev->virtqueues_lock, flags);
 
        return IRQ_HANDLED;
 }
@@ -516,6 +549,9 @@ static int _vop_add_device(struct mic_device_desc __iomem *d,
        vdev->vdev.priv = (void *)(unsigned long)dnode;
        init_completion(&vdev->reset_done);
 
+       INIT_LIST_HEAD(&vdev->virtqueues);
+       spin_lock_init(&vdev->virtqueues_lock);
+
        vdev->h2c_vdev_db = vpdev->hw_ops->next_db(vpdev);
        vdev->virtio_cookie = vpdev->hw_ops->request_irq(vpdev,
                        vop_virtio_intr_handler, "virtio intr",
-- 
2.20.0

Reply via email to