On 12/20/2016 3:43 AM, Alex Williamson wrote:
> As part of the mdev support, type1 now gets a task reference per
> vfio_dma and uses that to get an mm reference for the task while
> working on accounting.  That's correct, but it's not fast.  For some
> paths, like vfio_pin_pages_remote(), we know we're only called from
> user context, so we can restore the lighter weight calls.  In other
> cases, we're effectively already testing whether we're in the stored
> task context elsewhere; extend this to vfio_lock_acct() as well.
> 
> Signed-off-by: Alex Williamson <alex.william...@redhat.com>
> ---
> 
> v2: Use (mm == current->mm) test in vfio_lock_acct() as well rather
>     than passing around is_current.  It doesn't make sense to keep it
>     in vaddr_get_pfn() and not use it elsewhere.
> 

Thanks Alex. This change looks good to me.

Reviewed-by: Kirti Wankhede <kwankh...@nvidia.com>


Thanks,
Kirti

>  drivers/vfio/vfio_iommu_type1.c |   98 ++++++++++++++++++++-------------------
>  1 file changed, 51 insertions(+), 47 deletions(-)
> 
> diff --git a/drivers/vfio/vfio_iommu_type1.c b/drivers/vfio/vfio_iommu_type1.c
> index 9815e45..7154987 100644
> --- a/drivers/vfio/vfio_iommu_type1.c
> +++ b/drivers/vfio/vfio_iommu_type1.c
> @@ -268,28 +268,38 @@ static void vfio_lock_acct(struct task_struct *task, long npage)
>  {
>       struct vwork *vwork;
>       struct mm_struct *mm;
> +     bool is_current;
>  
>       if (!npage)
>               return;
>  
> -     mm = get_task_mm(task);
> +     is_current = (task->mm == current->mm);
> +
> +     mm = is_current ? task->mm : get_task_mm(task);
>       if (!mm)
> -             return; /* process exited or nothing to do */
> +             return; /* process exited */
>  
>       if (down_write_trylock(&mm->mmap_sem)) {
>               mm->locked_vm += npage;
>               up_write(&mm->mmap_sem);
> -             mmput(mm);
> +             if (!is_current)
> +                     mmput(mm);
>               return;
>       }
>  
> +     if (is_current) {
> +             mm = get_task_mm(task);
> +             if (!mm)
> +                     return;
> +     }
> +
>       /*
>        * Couldn't get mmap_sem lock, so must setup to update
>        * mm->locked_vm later. If locked_vm were atomic, we
>        * wouldn't need this silliness
>        */
>       vwork = kmalloc(sizeof(struct vwork), GFP_KERNEL);
> -     if (!vwork) {
> +     if (WARN_ON(!vwork)) {
>               mmput(mm);
>               return;
>       }
> @@ -393,77 +403,71 @@ static int vaddr_get_pfn(struct mm_struct *mm, unsigned long vaddr,
>  static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr,
>                                 long npage, unsigned long *pfn_base)
>  {
> -     unsigned long limit;
> -     bool lock_cap = ns_capable(task_active_pid_ns(dma->task)->user_ns,
> -                                CAP_IPC_LOCK);
> -     struct mm_struct *mm;
> -     long ret, i = 0, lock_acct = 0;
> +     unsigned long limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
> +     bool lock_cap = capable(CAP_IPC_LOCK);
> +     long ret, pinned = 0, lock_acct = 0;
>       bool rsvd;
>       dma_addr_t iova = vaddr - dma->vaddr + dma->iova;
>  
> -     mm = get_task_mm(dma->task);
> -     if (!mm)
> +     /* This code path is only user initiated */
> +     if (!current->mm)
>               return -ENODEV;
>  
> -     ret = vaddr_get_pfn(mm, vaddr, dma->prot, pfn_base);
> +     ret = vaddr_get_pfn(current->mm, vaddr, dma->prot, pfn_base);
>       if (ret)
> -             goto pin_pg_remote_exit;
> +             return ret;
>  
> +     pinned++;
>       rsvd = is_invalid_reserved_pfn(*pfn_base);
> -     limit = task_rlimit(dma->task, RLIMIT_MEMLOCK) >> PAGE_SHIFT;
>  
>       /*
>        * Reserved pages aren't counted against the user, externally pinned
>        * pages are already counted against the user.
>        */
>       if (!rsvd && !vfio_find_vpfn(dma, iova)) {
> -             if (!lock_cap && mm->locked_vm + 1 > limit) {
> +             if (!lock_cap && current->mm->locked_vm + 1 > limit) {
>                       put_pfn(*pfn_base, dma->prot);
>                       pr_warn("%s: RLIMIT_MEMLOCK (%ld) exceeded\n", __func__,
>                                       limit << PAGE_SHIFT);
> -                     ret = -ENOMEM;
> -                     goto pin_pg_remote_exit;
> +                     return -ENOMEM;
>               }
>               lock_acct++;
>       }
>  
> -     i++;
> -     if (likely(!disable_hugepages)) {
> -             /* Lock all the consecutive pages from pfn_base */
> -             for (vaddr += PAGE_SIZE, iova += PAGE_SIZE; i < npage;
> -                  i++, vaddr += PAGE_SIZE, iova += PAGE_SIZE) {
> -                     unsigned long pfn = 0;
> +     if (unlikely(disable_hugepages))
> +             goto out;
>  
> -                     ret = vaddr_get_pfn(mm, vaddr, dma->prot, &pfn);
> -                     if (ret)
> -                             break;
> +     /* Lock all the consecutive pages from pfn_base */
> +     for (vaddr += PAGE_SIZE, iova += PAGE_SIZE; pinned < npage;
> +          pinned++, vaddr += PAGE_SIZE, iova += PAGE_SIZE) {
> +             unsigned long pfn = 0;
>  
> -                     if (pfn != *pfn_base + i ||
> -                         rsvd != is_invalid_reserved_pfn(pfn)) {
> +             ret = vaddr_get_pfn(current->mm, vaddr, dma->prot, &pfn);
> +             if (ret)
> +                     break;
> +
> +             if (pfn != *pfn_base + pinned ||
> +                 rsvd != is_invalid_reserved_pfn(pfn)) {
> +                     put_pfn(pfn, dma->prot);
> +                     break;
> +             }
> +
> +             if (!rsvd && !vfio_find_vpfn(dma, iova)) {
> +                     if (!lock_cap &&
> +                         current->mm->locked_vm + lock_acct + 1 > limit) {
>                               put_pfn(pfn, dma->prot);
> +                             pr_warn("%s: RLIMIT_MEMLOCK (%ld) exceeded\n",
> +                                     __func__, limit << PAGE_SHIFT);
>                               break;
>                       }
> -
> -                     if (!rsvd && !vfio_find_vpfn(dma, iova)) {
> -                             if (!lock_cap &&
> -                                 mm->locked_vm + lock_acct + 1 > limit) {
> -                                     put_pfn(pfn, dma->prot);
> -                                     pr_warn("%s: RLIMIT_MEMLOCK (%ld) "
> -                                             "exceeded\n", __func__,
> -                                             limit << PAGE_SHIFT);
> -                                     break;
> -                             }
> -                             lock_acct++;
> -                     }
> +                     lock_acct++;
>               }
>       }
>  
> -     vfio_lock_acct(dma->task, lock_acct);
> -     ret = i;
> +out:
> +     vfio_lock_acct(current, lock_acct);
>  
> -pin_pg_remote_exit:
> -     mmput(mm);
> -     return ret;
> +     return pinned;
>  }
>  
>  static long vfio_unpin_pages_remote(struct vfio_dma *dma, dma_addr_t iova,
> @@ -473,10 +477,10 @@ static long vfio_unpin_pages_remote(struct vfio_dma *dma, dma_addr_t iova,
>       long unlocked = 0, locked = 0;
>       long i;
>  
> -     for (i = 0; i < npage; i++) {
> +     for (i = 0; i < npage; i++, iova += PAGE_SIZE) {
>               if (put_pfn(pfn++, dma->prot)) {
>                       unlocked++;
> -                     if (vfio_find_vpfn(dma, iova + (i << PAGE_SHIFT)))
> +                     if (vfio_find_vpfn(dma, iova))
>                               locked++;
>               }
>       }
> 

Reply via email to