vhci_hub_control() accesses the port_status array with an out-of-bounds
port value. Fix it to reference port_status[] only with a valid rhport
value, i.e. only when the invalid_rhport flag is false.

The invalid_rhport flag is set early on, after checking whether the port
value (wIndex) is within bounds.

The following is used to reproduce the problem and verify the fix:
C reproducer:   https://syzkaller.appspot.com/x/repro.c?x=14ed8ab6400000

Reported-by: syzbot+bccc1fe10b70fadc7...@syzkaller.appspotmail.com
Signed-off-by: Shuah Khan (Samsung OSG) <sh...@kernel.org>
---
 drivers/usb/usbip/vhci_hcd.c | 57 ++++++++++++++++++++++++++----------
 1 file changed, 42 insertions(+), 15 deletions(-)

diff --git a/drivers/usb/usbip/vhci_hcd.c b/drivers/usb/usbip/vhci_hcd.c
index d11f3f8dad40..1e592ec94ba4 100644
--- a/drivers/usb/usbip/vhci_hcd.c
+++ b/drivers/usb/usbip/vhci_hcd.c
@@ -318,8 +318,9 @@ static int vhci_hub_control(struct usb_hcd *hcd, u16 typeReq, u16 wValue,
        struct vhci_hcd *vhci_hcd;
        struct vhci     *vhci;
        int             retval = 0;
-       int             rhport;
+       int             rhport = -1;
        unsigned long   flags;
+       bool invalid_rhport = false;
 
        u32 prev_port_status[VHCI_HC_PORTS];
 
@@ -334,9 +335,19 @@ static int vhci_hub_control(struct usb_hcd *hcd, u16 typeReq, u16 wValue,
        usbip_dbg_vhci_rh("typeReq %x wValue %x wIndex %x\n", typeReq, wValue,
                          wIndex);
 
-       if (wIndex > VHCI_HC_PORTS)
-               pr_err("invalid port number %d\n", wIndex);
-       rhport = wIndex - 1;
+       /*
+        * wIndex can be 0 for some request types (typeReq). rhport is
+        * in valid range when wIndex >= 1 and <= VHCI_HC_PORTS.
+        *
+        * Reference port_status[] only with valid rhport when
+        * invalid_rhport is false.
+        */
+       if (wIndex < 1 || wIndex > VHCI_HC_PORTS) {
+               invalid_rhport = true;
+               if (wIndex > VHCI_HC_PORTS)
+                       pr_err("invalid port number %d\n", wIndex);
+       } else
+               rhport = wIndex - 1;
 
        vhci_hcd = hcd_to_vhci_hcd(hcd);
        vhci = vhci_hcd->vhci;
@@ -345,8 +356,9 @@ static int vhci_hub_control(struct usb_hcd *hcd, u16 typeReq, u16 wValue,
 
        /* store old status and compare now and old later */
        if (usbip_dbg_flag_vhci_rh) {
-               memcpy(prev_port_status, vhci_hcd->port_status,
-                       sizeof(prev_port_status));
+               if (!invalid_rhport)
+                       memcpy(prev_port_status, vhci_hcd->port_status,
+                               sizeof(prev_port_status));
        }
 
        switch (typeReq) {
@@ -354,8 +366,10 @@ static int vhci_hub_control(struct usb_hcd *hcd, u16 typeReq, u16 wValue,
                usbip_dbg_vhci_rh(" ClearHubFeature\n");
                break;
        case ClearPortFeature:
-               if (rhport < 0)
+               if (invalid_rhport) {
+                       pr_err("invalid port number %d\n", wIndex);
                        goto error;
+               }
                switch (wValue) {
                case USB_PORT_FEAT_SUSPEND:
                        if (hcd->speed == HCD_USB3) {
@@ -415,9 +429,10 @@ static int vhci_hub_control(struct usb_hcd *hcd, u16 typeReq, u16 wValue,
                break;
        case GetPortStatus:
                usbip_dbg_vhci_rh(" GetPortStatus port %x\n", wIndex);
-               if (wIndex < 1) {
+               if (invalid_rhport) {
                        pr_err("invalid port number %d\n", wIndex);
                        retval = -EPIPE;
+                       goto error;
                }
 
                /* we do not care about resume. */
@@ -513,16 +528,20 @@ static int vhci_hub_control(struct usb_hcd *hcd, u16 typeReq, u16 wValue,
                                goto error;
                        }
 
-                       if (rhport < 0)
+                       if (invalid_rhport) {
+                               pr_err("invalid port number %d\n", wIndex);
                                goto error;
+                       }
 
                        vhci_hcd->port_status[rhport] |= USB_PORT_STAT_SUSPEND;
                        break;
                case USB_PORT_FEAT_POWER:
                        usbip_dbg_vhci_rh(
                                " SetPortFeature: USB_PORT_FEAT_POWER\n");
-                       if (rhport < 0)
+                       if (invalid_rhport) {
+                               pr_err("invalid port number %d\n", wIndex);
                                goto error;
+                       }
                        if (hcd->speed == HCD_USB3)
                                vhci_hcd->port_status[rhport] |= USB_SS_PORT_STAT_POWER;
                        else
@@ -531,8 +550,10 @@ static int vhci_hub_control(struct usb_hcd *hcd, u16 typeReq, u16 wValue,
                case USB_PORT_FEAT_BH_PORT_RESET:
                        usbip_dbg_vhci_rh(
                                " SetPortFeature: USB_PORT_FEAT_BH_PORT_RESET\n");
-                       if (rhport < 0)
+                       if (invalid_rhport) {
+                               pr_err("invalid port number %d\n", wIndex);
                                goto error;
+                       }
                        /* Applicable only for USB3.0 hub */
                        if (hcd->speed != HCD_USB3) {
                                pr_err("USB_PORT_FEAT_BH_PORT_RESET req not "
@@ -543,8 +564,10 @@ static int vhci_hub_control(struct usb_hcd *hcd, u16 typeReq, u16 wValue,
                case USB_PORT_FEAT_RESET:
                        usbip_dbg_vhci_rh(
                                " SetPortFeature: USB_PORT_FEAT_RESET\n");
-                       if (rhport < 0)
+                       if (invalid_rhport) {
+                               pr_err("invalid port number %d\n", wIndex);
                                goto error;
+                       }
                        /* if it's already enabled, disable */
                        if (hcd->speed == HCD_USB3) {
                                vhci_hcd->port_status[rhport] = 0;
@@ -565,8 +588,10 @@ static int vhci_hub_control(struct usb_hcd *hcd, u16 typeReq, u16 wValue,
                default:
                        usbip_dbg_vhci_rh(" SetPortFeature: default %d\n",
                                          wValue);
-                       if (rhport < 0)
+                       if (invalid_rhport) {
+                               pr_err("invalid port number %d\n", wIndex);
                                goto error;
+                       }
                        if (hcd->speed == HCD_USB3) {
                                if ((vhci_hcd->port_status[rhport] &
                                     USB_SS_PORT_STAT_POWER) != 0) {
@@ -608,7 +633,7 @@ static int vhci_hub_control(struct usb_hcd *hcd, u16 typeReq, u16 wValue,
        if (usbip_dbg_flag_vhci_rh) {
                pr_debug("port %d\n", rhport);
                /* Only dump valid port status */
-               if (rhport >= 0) {
+               if (!invalid_rhport) {
                        dump_port_status_diff(prev_port_status[rhport],
                                              vhci_hcd->port_status[rhport],
                                              hcd->speed == HCD_USB3);
@@ -618,8 +643,10 @@ static int vhci_hub_control(struct usb_hcd *hcd, u16 typeReq, u16 wValue,
 
        spin_unlock_irqrestore(&vhci->lock, flags);
 
-       if ((vhci_hcd->port_status[rhport] & PORT_C_MASK) != 0)
+       if (!invalid_rhport &&
+           (vhci_hcd->port_status[rhport] & PORT_C_MASK) != 0) {
                usb_hcd_poll_rh_status(hcd);
+       }
 
        return retval;
 }
-- 
2.17.0

Reply via email to