From: Johannes Berg <johannes.b...@intel.com>

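This code can be called deep inside IRQ handling, for
example, where it cannot normally use kmalloc(). Give
each device its own pre-allocated message buffers and
use those instead, so no allocation is needed. If all
buffers are busy, the command falls back to a spare
buffer and is sent synchronously. Only in the (very
rare) case of a posted memcpy_toio() whose data doesn't
fit into a buffer do we still need to allocate memory.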

Signed-off-by: Johannes Berg <johannes.b...@intel.com>
---
 arch/um/drivers/virt-pci.c | 196 +++++++++++++++++++------------------
 1 file changed, 101 insertions(+), 95 deletions(-)

diff --git a/arch/um/drivers/virt-pci.c b/arch/um/drivers/virt-pci.c
index 744e7f31e8ef..dd5580f975cc 100644
--- a/arch/um/drivers/virt-pci.c
+++ b/arch/um/drivers/virt-pci.c
@@ -25,8 +25,10 @@
 #define MAX_IRQ_MSG_SIZE (sizeof(struct virtio_pcidev_msg) + sizeof(u32))
 #define NUM_IRQ_MSGS   10
 
-#define HANDLE_NO_FREE(ptr) ((void *)((unsigned long)(ptr) | 1))
-#define HANDLE_IS_NO_FREE(ptr) ((unsigned long)(ptr) & 1)
+struct um_pci_message_buffer {
+       struct virtio_pcidev_msg hdr;
+       u8 data[8];
+};
 
 struct um_pci_device {
        struct virtio_device *vdev;
@@ -36,6 +38,11 @@ struct um_pci_device {
 
        struct virtqueue *cmd_vq, *irq_vq;
 
+#define UM_PCI_WRITE_BUFS      20
+       struct um_pci_message_buffer bufs[UM_PCI_WRITE_BUFS + 1];
+       void *extra_ptrs[UM_PCI_WRITE_BUFS + 1];
+       DECLARE_BITMAP(used_bufs, UM_PCI_WRITE_BUFS);
+
 #define UM_PCI_STAT_WAITING    0
        unsigned long status;
 
@@ -61,12 +68,53 @@ static unsigned long um_pci_msi_used[BITS_TO_LONGS(MAX_MSI_VECTORS)];
 static unsigned int um_pci_max_delay_us = 40000;
 module_param_named(max_delay_us, um_pci_max_delay_us, uint, 0644);

-struct um_pci_message_buffer {
-       struct virtio_pcidev_msg hdr;
-       u8 data[8];
-};
+/*
+ * Claim a free pre-allocated message buffer and return its index in
+ * dev->bufs. If all UM_PCI_WRITE_BUFS buffers are in use, return the
+ * spare slot UM_PCI_WRITE_BUFS instead and clear *posted, so that
+ * um_pci_send_cmd() waits for completion rather than posting the command.
+ * NOTE(review): the spare slot is not tracked in used_bufs, so two
+ * concurrent fallbacks would share it - confirm that cannot happen.
+ */
+static int um_pci_get_buf(struct um_pci_device *dev, bool *posted)
+{
+       int i;

-static struct um_pci_message_buffer __percpu *um_pci_msg_bufs;
+       for (i = 0; i < UM_PCI_WRITE_BUFS; i++) {
+               if (!test_and_set_bit(i, dev->used_bufs))
+                       return i;
+       }
+
+       *posted = false;
+       return UM_PCI_WRITE_BUFS;
+}
+
+/*
+ * Release a buffer obtained from um_pci_get_buf(): kfree() any extra data
+ * copy attached to it and clear its used_bufs bit (the spare slot has
+ * none). buf must be one of dev->bufs; anything else triggers a WARN.
+ */
+static void um_pci_free_buf(struct um_pci_device *dev, void *buf)
+{
+       int i;
+
+       if (buf == &dev->bufs[UM_PCI_WRITE_BUFS]) {
+               kfree(dev->extra_ptrs[UM_PCI_WRITE_BUFS]);
+               dev->extra_ptrs[UM_PCI_WRITE_BUFS] = NULL;
+               return;
+       }
+
+       for (i = 0; i < UM_PCI_WRITE_BUFS; i++) {
+               if (buf == &dev->bufs[i]) {
+                       kfree(dev->extra_ptrs[i]);
+                       dev->extra_ptrs[i] = NULL;
+                       WARN_ON(!test_and_clear_bit(i, dev->used_bufs));
+                       return;
+               }
+       }
+
+       WARN_ON(1);
+}

 static int um_pci_send_cmd(struct um_pci_device *dev,
                           struct virtio_pcidev_msg *cmd,
@@ -82,7 +117,9 @@ static int um_pci_send_cmd(struct um_pci_device *dev,
        };
        struct um_pci_message_buffer *buf;
        int delay_count = 0;
+       bool bounce_out;
        int ret, len;
+       int buf_idx;
        bool posted;
 
        if (WARN_ON(cmd_size < sizeof(*cmd) || cmd_size > sizeof(*buf)))
@@ -101,26 +138,36 @@ static int um_pci_send_cmd(struct um_pci_device *dev,
                break;
        }

-       buf = get_cpu_var(um_pci_msg_bufs);
-       if (buf)
-               memcpy(buf, cmd, cmd_size);
+       bounce_out = !posted && cmd_size <= sizeof(*cmd) &&
+                    out && out_size <= sizeof(buf->data);

-       if (posted) {
-               u8 *ncmd = kmalloc(cmd_size + extra_size, GFP_ATOMIC);
+       buf_idx = um_pci_get_buf(dev, &posted);
+       buf = &dev->bufs[buf_idx];
+       memcpy(buf, cmd, cmd_size);

-               if (ncmd) {
-                       memcpy(ncmd, cmd, cmd_size);
-                       if (extra)
-                               memcpy(ncmd + cmd_size, extra, extra_size);
-                       cmd = (void *)ncmd;
-                       cmd_size += extra_size;
-                       extra = NULL;
-                       extra_size = 0;
-               } else {
-                       /* try without allocating memory */
-                       posted = false;
-                       cmd = (void *)buf;
+       /*
+        * From here on cmd must point at buf: it is the token passed to
+        * virtqueue_add_sgs() and is handed back to um_pci_free_buf() on
+        * completion. The room left after the header is sizeof(*buf) minus
+        * cmd_size, which relies on the WARN_ON() size check above bailing
+        * out for cmd_size > sizeof(*buf).
+        */
+       if (posted && extra && extra_size > sizeof(*buf) - cmd_size) {
+               /* won't fit in buf and we don't wait: keep a private copy */
+               cmd = (void *)buf;
+               dev->extra_ptrs[buf_idx] = kmemdup(extra, extra_size,
+                                                  GFP_ATOMIC);
+               if (!dev->extra_ptrs[buf_idx]) {
+                       um_pci_free_buf(dev, buf);
+                       return -ENOMEM;
                }
+               extra = dev->extra_ptrs[buf_idx];
+       } else if (extra && extra_size <= sizeof(*buf) - cmd_size) {
+               memcpy((u8 *)buf + cmd_size, extra, extra_size);
+               cmd_size += extra_size;
+               extra_size = 0;
+               extra = NULL;
+               cmd = (void *)buf;
        } else {
                cmd = (void *)buf;
        }
@@ -128,39 +167,40 @@ static int um_pci_send_cmd(struct um_pci_device *dev,
        sg_init_one(&out_sg, cmd, cmd_size);
        if (extra)
                sg_init_one(&extra_sg, extra, extra_size);
-       if (out)
+       /* bounce small responses through buf->data so 'out' may be on stack */
+       if (bounce_out)
+               sg_init_one(&in_sg, buf->data, out_size);
+       else if (out)
                sg_init_one(&in_sg, out, out_size);
 
        /* add to internal virtio queue */
        ret = virtqueue_add_sgs(dev->cmd_vq, sgs_list,
                                extra ? 2 : 1,
                                out ? 1 : 0,
-                               posted ? cmd : HANDLE_NO_FREE(cmd),
-                               GFP_ATOMIC);
+                               cmd, GFP_ATOMIC);
        if (ret) {
-               if (posted)
-                       kfree(cmd);
-               goto out;
+               um_pci_free_buf(dev, buf);
+               return ret;
        }
 
        if (posted) {
                virtqueue_kick(dev->cmd_vq);
-               ret = 0;
-               goto out;
+               return 0;
        }
 
        /* kick and poll for getting a response on the queue */
        set_bit(UM_PCI_STAT_WAITING, &dev->status);
        virtqueue_kick(dev->cmd_vq);
+       ret = 0;
 
        while (1) {
                void *completed = virtqueue_get_buf(dev->cmd_vq, &len);
 
-               if (completed == HANDLE_NO_FREE(cmd))
+               if (completed == buf)
                        break;
 
-               if (completed && !HANDLE_IS_NO_FREE(completed))
-                       kfree(completed);
+               if (completed)
+                       um_pci_free_buf(dev, completed);
 
                if (WARN_ONCE(virtqueue_is_broken(dev->cmd_vq) ||
                              ++delay_count > um_pci_max_delay_us,
@@ -172,8 +212,11 @@ static int um_pci_send_cmd(struct um_pci_device *dev,
        }
        clear_bit(UM_PCI_STAT_WAITING, &dev->status);
 
-out:
-       put_cpu_var(um_pci_msg_bufs);
+       if (bounce_out)
+               memcpy(out, buf->data, out_size);
+
+       um_pci_free_buf(dev, buf);
+
        return ret;
 }
 
@@ -187,20 +230,13 @@ static unsigned long um_pci_cfgspace_read(void *priv, unsigned int offset,
                .size = size,
                .addr = offset,
        };
-       /* buf->data is maximum size - we may only use parts of it */
-       struct um_pci_message_buffer *buf;
-       u8 *data;
-       unsigned long ret = ULONG_MAX;
-       size_t bytes = sizeof(buf->data);
+       /* max 8, we might not use it all */
+       u8 data[8];
 
        if (!dev)
                return ULONG_MAX;
 
-       buf = get_cpu_var(um_pci_msg_bufs);
-       data = buf->data;
-
-       if (buf)
-               memset(data, 0xff, bytes);
+       memset(data, 0xff, sizeof(data));
 
        switch (size) {
        case 1:
@@ -212,34 +248,26 @@ static unsigned long um_pci_cfgspace_read(void *priv, unsigned int offset,
                break;
        default:
                WARN(1, "invalid config space read size %d\n", size);
-               goto out;
+               return ULONG_MAX;
        }
 
-       if (um_pci_send_cmd(dev, &hdr, sizeof(hdr), NULL, 0, data, bytes))
-               goto out;
+       if (um_pci_send_cmd(dev, &hdr, sizeof(hdr), NULL, 0, data, size))
+               return ULONG_MAX;
 
        switch (size) {
        case 1:
-               ret = data[0];
-               break;
+               return data[0];
        case 2:
-               ret = le16_to_cpup((void *)data);
-               break;
+               return le16_to_cpup((void *)data);
        case 4:
-               ret = le32_to_cpup((void *)data);
-               break;
+               return le32_to_cpup((void *)data);
 #ifdef CONFIG_64BIT
        case 8:
-               ret = le64_to_cpup((void *)data);
-               break;
+               return le64_to_cpup((void *)data);
 #endif
        default:
-               break;
+               return ULONG_MAX;
        }
-
-out:
-       put_cpu_var(um_pci_msg_bufs);
-       return ret;
 }
 
 static void um_pci_cfgspace_write(void *priv, unsigned int offset, int size,
@@ -312,13 +340,8 @@ static void um_pci_bar_copy_from(void *priv, void *buffer,
 static unsigned long um_pci_bar_read(void *priv, unsigned int offset,
                                     int size)
 {
-       /* buf->data is maximum size - we may only use parts of it */
-       struct um_pci_message_buffer *buf;
-       u8 *data;
-       unsigned long ret = ULONG_MAX;
-
-       buf = get_cpu_var(um_pci_msg_bufs);
-       data = buf->data;
+       /* 8 is maximum size - we may only use parts of it */
+       u8 data[8];
 
        switch (size) {
        case 1:
@@ -330,33 +353,25 @@ static unsigned long um_pci_bar_read(void *priv, unsigned int offset,
                break;
        default:
                WARN(1, "invalid config space read size %d\n", size);
-               goto out;
+               return ULONG_MAX;
        }
 
        um_pci_bar_copy_from(priv, data, offset, size);
 
        switch (size) {
        case 1:
-               ret = data[0];
-               break;
+               return data[0];
        case 2:
-               ret = le16_to_cpup((void *)data);
-               break;
+               return le16_to_cpup((void *)data);
        case 4:
-               ret = le32_to_cpup((void *)data);
-               break;
+               return le32_to_cpup((void *)data);
 #ifdef CONFIG_64BIT
        case 8:
-               ret = le64_to_cpup((void *)data);
-               break;
+               return le64_to_cpup((void *)data);
 #endif
        default:
-               break;
+               return ULONG_MAX;
        }
-
-out:
-       put_cpu_var(um_pci_msg_bufs);
-       return ret;
 }
 
 static void um_pci_bar_copy_to(void *priv, unsigned int offset,
@@ -523,11 +538,8 @@ static void um_pci_cmd_vq_cb(struct virtqueue *vq)
        if (test_bit(UM_PCI_STAT_WAITING, &dev->status))
                return;
 
-       while ((cmd = virtqueue_get_buf(vq, &len))) {
-               if (WARN_ON(HANDLE_IS_NO_FREE(cmd)))
-                       continue;
-               kfree(cmd);
-       }
+       while ((cmd = virtqueue_get_buf(vq, &len)))
+               um_pci_free_buf(dev, cmd);
 }
 
 static void um_pci_irq_vq_cb(struct virtqueue *vq)
@@ -1006,10 +1018,6 @@ static int __init um_pci_init(void)
                 "No virtio device ID configured for PCI - no PCI support\n"))
                return 0;
 
-       um_pci_msg_bufs = alloc_percpu(struct um_pci_message_buffer);
-       if (!um_pci_msg_bufs)
-               return -ENOMEM;
-
        bridge = pci_alloc_host_bridge(0);
        if (!bridge) {
                err = -ENOMEM;
@@ -1070,7 +1078,6 @@ static int __init um_pci_init(void)
                pci_free_resource_list(&bridge->windows);
                pci_free_host_bridge(bridge);
        }
-       free_percpu(um_pci_msg_bufs);
        return err;
 }
 module_init(um_pci_init);
@@ -1082,6 +1089,5 @@ static void __exit um_pci_exit(void)
        irq_domain_remove(um_pci_inner_domain);
        pci_free_resource_list(&bridge->windows);
        pci_free_host_bridge(bridge);
-       free_percpu(um_pci_msg_bufs);
 }
 module_exit(um_pci_exit);
-- 
2.47.1


Reply via email to