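When an invalid vhost-user message is passed, the connection should be shut down.
Replace the early "return -1" error paths in vhost_user_call() with a jump to a
common exit that calls qemu_chr_disconnect() on the vhost-user chardev.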

Signed-off-by: Tetsuya Mukawa <muk...@igel.co.jp>
---
 hw/virtio/vhost-user.c | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/hw/virtio/vhost-user.c b/hw/virtio/vhost-user.c
index d6f2163..2215c39 100644
--- a/hw/virtio/vhost-user.c
+++ b/hw/virtio/vhost-user.c
@@ -183,6 +183,7 @@ static int vhost_user_write(struct vhost_dev *dev, VhostUserMsg *msg,
 static int vhost_user_call(struct vhost_dev *dev, unsigned long int request,
         void *arg)
 {
+    CharDriverState *chr = dev->opaque;
     VhostUserMsg msg;
     VhostUserRequest msg_request;
     struct vhost_vring_file *file = 0;
@@ -242,7 +243,7 @@ static int vhost_user_call(struct vhost_dev *dev, unsigned long int request,
         if (!fd_num) {
             error_report("Failed initializing vhost-user memory map, "
                     "consider using -object memory-backend-file share=on");
-            return -1;
+            goto close;
         }
 
         msg.size = sizeof(m.memory.nregions);
@@ -289,7 +290,7 @@ static int vhost_user_call(struct vhost_dev *dev, unsigned long int request,
         break;
     default:
         error_report("vhost-user trying to send unhandled ioctl");
-        return -1;
+        goto close;
         break;
     }
 
@@ -305,33 +306,36 @@ static int vhost_user_call(struct vhost_dev *dev, unsigned long int request,
         if (msg_request != msg.request) {
             error_report("Received unexpected msg type."
                     " Expected %d received %d", msg_request, msg.request);
-            return -1;
+            goto close;
         }
 
         switch (msg_request) {
         case VHOST_USER_GET_FEATURES:
             if (msg.size != sizeof(m.u64)) {
                 error_report("Received bad msg size.");
-                return -1;
+                goto close;
             }
             *((__u64 *) arg) = msg.u64;
             break;
         case VHOST_USER_GET_VRING_BASE:
             if (msg.size != sizeof(m.state)) {
                 error_report("Received bad msg size.");
-                return -1;
+                goto close;
             }
             msg.state.index -= dev->vq_index;
             memcpy(arg, &msg.state, sizeof(struct vhost_vring_state));
             break;
         default:
             error_report("Received unexpected msg type.");
-            return -1;
-            break;
+            goto close;
         }
     }
 
     return 0;
+
+close:
+    qemu_chr_disconnect(chr);
+    return -1;
 }
 
 static int vhost_user_init(struct vhost_dev *dev, void *opaque)
-- 
2.1.4


Reply via email to