On Wed, Nov 18, 2015 at 11:15:25AM +0000, Xie, Huawei wrote:
> On 11/18/2015 4:47 PM, Yuanhan Liu wrote:
> > On Wed, Nov 18, 2015 at 07:53:24AM +0000, Xie, Huawei wrote:
> > ...
> >>>   do {
> >>> +         if (vec_id >= BUF_VECTOR_MAX)
> >>> +                 break;
> >>> +
> >>>           next_desc = 0;
> >>>           len += vq->desc[idx].len;
> >>>           vq->buf_vec[vec_id].buf_addr = vq->desc[idx].addr;
> >>> @@ -519,6 +526,8 @@ virtio_dev_merge_rx(struct virtio_net *dev, uint16_t 
> >>> queue_id,
> >>>                                   goto merge_rx_exit;
> >>>                           } else {
> >>>                                   update_secure_len(vq, res_cur_idx, 
> >>> &secure_len, &vec_idx);
> >>> +                                 if (secure_len == 0)
> >>> +                                         goto merge_rx_exit;
> >> Why do we exit when secure_len is 0 rather than 1? :). A malicious guest
> > I confess it's not a proper fix. Making it return an error code, as Rich
> > suggested in an earlier email, is better. It's generic enough, as we
> > have to check for vec_buf overflow here.
> >
> > BTW, can we move the vec_buf outside `struct vhost_virtqueue'? It makes
> > the structure huge.
> >
> >> could easily forge the desc len so that secure_len never reaches pkt_len
> >> even if it is not zero, so that the host enters an endless loop here.
> >> Generally speaking, we shouldn't fix just one specific issue,
> > Agreed.
> >
> >> and the
> >> security checks should be as few as possible.
> > Ideally, yes.
> >
> >> We need to consider
> >> refactoring the code here for a generic fix.
> > What are your thoughts?
> Maybe we could merge update_secure_len with the outer loop into a simple
> inline function, in which we consider both the max vector number and the
> desc count to avoid being trapped in an endless loop. This function
> returns a buf vec with which we could copy securely afterwards.

I agree that grouping them into a function makes the logic clearer, and
hence less error-prone.
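
To make the failure mode concrete: a (purely hypothetical) guest could
publish a circular descriptor chain like the one below; the original
update_secure_len() would walk it forever, since next_desc never clears,
and vec_id would also run past BUF_VECTOR_MAX on the way. The indices and
values here are made up for illustration only:

        /* guest-controlled vring: two descriptors chained into a cycle */
        desc[0].addr  = guest_buf0;     /* hypothetical guest buffer */
        desc[0].len   = 0;              /* forged: adds nothing to secure_len */
        desc[0].flags = VRING_DESC_F_NEXT;
        desc[0].next  = 1;

        desc[1].addr  = guest_buf1;
        desc[1].len   = 0;
        desc[1].flags = VRING_DESC_F_NEXT;
        desc[1].next  = 0;              /* cycles back to desc[0] */

Hence the patch below bounds the chain walk by BUF_VECTOR_MAX and the
per-packet reservation by vq->size tries.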

I made a quick try. Comments? (A stripped-down sketch of the reservation
pattern follows the diff.)

        --yliu


---
diff --git a/lib/librte_vhost/vhost_rxtx.c b/lib/librte_vhost/vhost_rxtx.c
index 4fc35d1..e270fb1 100644
--- a/lib/librte_vhost/vhost_rxtx.c
+++ b/lib/librte_vhost/vhost_rxtx.c
@@ -439,32 +439,98 @@ copy_from_mbuf_to_vring(struct virtio_net *dev, uint32_t queue_id,
        return entry_success;
 }

-static inline void __attribute__((always_inline))
-update_secure_len(struct vhost_virtqueue *vq, uint32_t id,
-       uint32_t *secure_len, uint32_t *vec_idx)
+static inline int
+fill_vec_buf(struct vhost_virtqueue *vq, uint32_t avail_idx,
+            uint32_t *allocated, uint32_t *vec_idx)
 {
-       uint16_t wrapped_idx = id & (vq->size - 1);
-       uint32_t idx = vq->avail->ring[wrapped_idx];
-       uint8_t next_desc;
-       uint32_t len = *secure_len;
+       uint16_t idx = vq->avail->ring[avail_idx & (vq->size - 1)];
        uint32_t vec_id = *vec_idx;
+       uint32_t len    = *allocated;
+
+       while (1) {
+               if (vec_id >= BUF_VECTOR_MAX)
+                       return -1;

-       do {
-               next_desc = 0;
                len += vq->desc[idx].len;
                vq->buf_vec[vec_id].buf_addr = vq->desc[idx].addr;
                vq->buf_vec[vec_id].buf_len = vq->desc[idx].len;
                vq->buf_vec[vec_id].desc_idx = idx;
                vec_id++;

-               if (vq->desc[idx].flags & VRING_DESC_F_NEXT) {
-                       idx = vq->desc[idx].next;
-                       next_desc = 1;
+               if ((vq->desc[idx].flags & VRING_DESC_F_NEXT) == 0)
+                       break;
+
+               idx = vq->desc[idx].next;
+       }
+
+       *allocated = len;
+       *vec_idx   = vec_id;
+
+       return 0;
+}
+
+/*
+ * As many data cores may want to access available buffers concurrently,
+ * they need to be reserved.
+ *
+ * Returns -1 on failure, 0 on success
+ */
+static inline int
+reserve_avail_buf(struct virtio_net *dev, struct vhost_virtqueue *vq,
+                 uint32_t size, uint16_t *start, uint16_t *end)
+{
+       uint16_t res_base_idx;
+       uint16_t res_cur_idx;
+       uint16_t avail_idx;
+       uint32_t allocated;
+       uint32_t vec_idx;
+       uint16_t tries;
+
+again:
+       res_base_idx = vq->last_used_idx_res;
+       res_cur_idx  = res_base_idx;
+
+       allocated = 0;
+       vec_idx   = 0;
+       tries     = 0;
+       while (1) {
+               avail_idx = *((volatile uint16_t *)&vq->avail->idx);
+               if (unlikely(res_cur_idx == avail_idx)) {
+                       LOG_DEBUG(VHOST_DATA, "(%"PRIu64") Failed "
+                               "to get enough desc from vring\n",
+                               dev->device_fh);
+                       return -1;
                }
-       } while (next_desc);

-       *secure_len = len;
-       *vec_idx = vec_id;
+               if (fill_vec_buf(vq, res_cur_idx, &allocated, &vec_idx) < 0)
+                       return -1;
+
+               res_cur_idx++;
+               tries++;
+
+               if (allocated >= size)
+                       break;
+
+               /*
+                * if we have tried all available ring items and still
+                * can't get enough buffers, something abnormal
+                * happened.
+                */
+               if (tries >= vq->size)
+                       return -1;
+       }
+
+       /*
+        * update vq->last_used_idx_res atomically;
+        * retry if it failed.
+        */
+       if (rte_atomic16_cmpset(&vq->last_used_idx_res,
+                               res_base_idx, res_cur_idx) == 0)
+               goto again;
+
+       *start = res_base_idx;
+       *end   = res_cur_idx;
+       return 0;
 }

 /*
@@ -476,9 +542,7 @@ virtio_dev_merge_rx(struct virtio_net *dev, uint16_t queue_id,
 {
        struct vhost_virtqueue *vq;
        uint32_t pkt_idx = 0, entry_success = 0;
-       uint16_t avail_idx;
-       uint16_t res_base_idx, res_cur_idx;
-       uint8_t success = 0;
+       uint16_t start, end;

        LOG_DEBUG(VHOST_DATA, "(%"PRIu64") virtio_dev_merge_rx()\n",
                dev->device_fh);
@@ -501,40 +565,11 @@ virtio_dev_merge_rx(struct virtio_net *dev, uint16_t queue_id,
        for (pkt_idx = 0; pkt_idx < count; pkt_idx++) {
                uint32_t pkt_len = pkts[pkt_idx]->pkt_len + vq->vhost_hlen;

-               do {
-                       /*
-                        * As many data cores may want access to available
-                        * buffers, they need to be reserved.
-                        */
-                       uint32_t secure_len = 0;
-                       uint32_t vec_idx = 0;
-
-                       res_base_idx = vq->last_used_idx_res;
-                       res_cur_idx = res_base_idx;
-
-                       do {
-                               avail_idx = *((volatile uint16_t *)&vq->avail->idx);
-                               if (unlikely(res_cur_idx == avail_idx)) {
-                                       LOG_DEBUG(VHOST_DATA,
-                                               "(%"PRIu64") Failed "
-                                               "to get enough desc from "
-                                               "vring\n",
-                                               dev->device_fh);
-                                       goto merge_rx_exit;
-                               } else {
-                                       update_secure_len(vq, res_cur_idx, &secure_len, &vec_idx);
-                                       res_cur_idx++;
-                               }
-                       } while (pkt_len > secure_len);
-
-                       /* vq->last_used_idx_res is atomically updated. */
-                       success = rte_atomic16_cmpset(&vq->last_used_idx_res,
-                                                       res_base_idx,
-                                                       res_cur_idx);
-               } while (success == 0);
+               if (reserve_avail_buf(dev, vq, pkt_len, &start, &end) < 0)
+                       break;

                entry_success = copy_from_mbuf_to_vring(dev, queue_id,
-                       res_base_idx, res_cur_idx, pkts[pkt_idx]);
+                       start, end, pkts[pkt_idx]);

                rte_compiler_barrier();

@@ -542,14 +577,13 @@ virtio_dev_merge_rx(struct virtio_net *dev, uint16_t queue_id,
                 * Wait until it's our turn to add our buffer
                 * to the used ring.
                 */
-               while (unlikely(vq->last_used_idx != res_base_idx))
+               while (unlikely(vq->last_used_idx != start))
                        rte_pause();

                *(volatile uint16_t *)&vq->used->idx += entry_success;
-               vq->last_used_idx = res_cur_idx;
+               vq->last_used_idx = end;
        }

-merge_rx_exit:
        if (likely(pkt_idx)) {
                /* flush used->idx update before we read avail->flags. */
                rte_mb();
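
P.S. For anyone new to the reservation scheme the patch keeps: stripped
to its essence, it is a compare-and-swap loop over a shared 16-bit
cursor, and contending cores simply retry. A minimal standalone sketch,
with simplified names and GCC's __sync_bool_compare_and_swap standing in
for rte_atomic16_cmpset (this is not the actual DPDK code):

        #include <stdint.h>
        #include <stdbool.h>

        /*
         * Claim [start, end) of the ring by advancing a shared cursor
         * with a 16-bit compare-and-swap, retrying on contention.
         * Indices wrap naturally in uint16_t arithmetic, just like
         * last_used_idx_res and avail->idx do. avail_idx is a snapshot
         * here; the real code re-reads it on every iteration.
         */
        static bool
        reserve_range(volatile uint16_t *cursor, uint16_t avail_idx,
                      uint16_t count, uint16_t *start, uint16_t *end)
        {
                uint16_t base, cur;

                do {
                        base = *cursor;
                        cur  = base + count;
                        /* entries left for us: avail_idx - base (mod 2^16) */
                        if ((uint16_t)(avail_idx - base) < count)
                                return false;
                } while (!__sync_bool_compare_and_swap(cursor, base, cur));

                *start = base;
                *end   = cur;
                return true;
        }

In the patch the cursor is vq->last_used_idx_res and the count is
discovered incrementally by fill_vec_buf(), but the retry structure is
the same.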
