Module Name:    src
Committed By:   christos
Date:           Sat Jul 29 17:54:54 UTC 2023

Modified Files:
        src/sys/kern: sys_memfd.c
        src/sys/sys: memfd.h

Log Message:
Fix locking and offset issues pointed out by @riastradh (Theodore Preduta)


To generate a diff of this commit:
cvs rdiff -u -r1.5 -r1.6 src/sys/kern/sys_memfd.c
cvs rdiff -u -r1.3 -r1.4 src/sys/sys/memfd.h

Please note that diffs are not public domain; they are subject to the
copyright notices on the relevant files.

Modified files:

Index: src/sys/kern/sys_memfd.c
diff -u src/sys/kern/sys_memfd.c:1.5 src/sys/kern/sys_memfd.c:1.6
--- src/sys/kern/sys_memfd.c:1.5	Sat Jul 29 08:16:34 2023
+++ src/sys/kern/sys_memfd.c	Sat Jul 29 13:54:54 2023
@@ -1,4 +1,4 @@
-/*	$NetBSD: sys_memfd.c,v 1.5 2023/07/29 12:16:34 christos Exp $	*/
+/*	$NetBSD: sys_memfd.c,v 1.6 2023/07/29 17:54:54 christos Exp $	*/
 
 /*-
  * Copyright (c) 2023 The NetBSD Foundation, Inc.
@@ -30,7 +30,7 @@
  */
 
 #include <sys/cdefs.h>
-__KERNEL_RCSID(0, "$NetBSD: sys_memfd.c,v 1.5 2023/07/29 12:16:34 christos Exp $");
+__KERNEL_RCSID(0, "$NetBSD: sys_memfd.c,v 1.6 2023/07/29 17:54:54 christos Exp $");
 
 #include <sys/param.h>
 #include <sys/types.h>
@@ -60,6 +60,7 @@ static int memfd_close(file_t *);
 static int memfd_mmap(file_t *, off_t *, size_t, int, int *, int *,
     struct uvm_object **, int *);
 static int memfd_seek(file_t *, off_t, int, off_t *, int);
+static int do_memfd_truncate(file_t *, off_t);
 static int memfd_truncate(file_t *, off_t);
 
 static const struct fileops memfd_fileops = {
@@ -104,7 +105,6 @@ sys_memfd_create(struct lwp *l, const st
 	mfd = kmem_zalloc(sizeof(*mfd), KM_SLEEP);
 	mfd->mfd_size = 0;
 	mfd->mfd_uobj = uao_create(INT64_MAX - PAGE_SIZE, 0); /* same as tmpfs */
-	mutex_init(&mfd->mfd_lock, MUTEX_DEFAULT, IPL_NONE);
 
 	CTASSERT(sizeof(memfd_prefix) < NAME_MAX); /* sanity check */
 	strcpy(mfd->mfd_name, memfd_prefix);
@@ -147,8 +147,7 @@ memfd_read(file_t *fp, off_t *offp, stru
 	vsize_t todo;
 	struct memfd *mfd = fp->f_memfd;
 
-	if (offp == &fp->f_offset)
-		mutex_enter(&fp->f_lock);
+	mutex_enter(&fp->f_lock);
 
 	if (*offp < 0) {
 		error = EINVAL;
@@ -161,18 +160,19 @@ memfd_read(file_t *fp, off_t *offp, stru
 		goto leave;
 	}
 
-	uio->uio_offset = *offp;
+	if (flags & FOF_UPDATE_OFFSET)
+		*offp = uio->uio_offset;
 	todo = MIN(uio->uio_resid, mfd->mfd_size - *offp);
 	error = ubc_uiomove(mfd->mfd_uobj, uio, todo, UVM_ADV_SEQUENTIAL,
 	    UBC_READ|UBC_PARTIALOK);
 	*offp = uio->uio_offset;
 
 leave:
-	if (offp == &fp->f_offset)
-		mutex_exit(&fp->f_lock);
-
 	getnanotime(&mfd->mfd_atime);
 
+
+	mutex_exit(&fp->f_lock);
+
 	return error;
 }
 
@@ -184,11 +184,12 @@ memfd_write(file_t *fp, off_t *offp, str
 	vsize_t todo;
 	struct memfd *mfd = fp->f_memfd;
 
-	if (mfd->mfd_seals & F_SEAL_ANY_WRITE)
-		return EPERM;
+	mutex_enter(&fp->f_lock);
 
-	if (offp == &fp->f_offset)
-		mutex_enter(&fp->f_lock);
+	if (mfd->mfd_seals & F_SEAL_ANY_WRITE) {
+		error = EPERM;
+		goto leave;
+	}
 
 	if (*offp < 0) {
 		error = EINVAL;
@@ -209,20 +210,20 @@ memfd_write(file_t *fp, off_t *offp, str
 			todo = mfd->mfd_size - *offp;
 	} else if (*offp + uio->uio_resid >= mfd->mfd_size) {
 		/* Grow to accommodate the write request. */
-		error = memfd_truncate(fp, *offp + uio->uio_resid);
+		error = do_memfd_truncate(fp, *offp + uio->uio_resid);
 		if (error != 0)
 			goto leave;
 	}
 
 	error = ubc_uiomove(mfd->mfd_uobj, uio, todo, UVM_ADV_SEQUENTIAL,
 	    UBC_WRITE|UBC_PARTIALOK);
-	*offp = uio->uio_offset;
+	if (flags & FOF_UPDATE_OFFSET)
+		*offp = uio->uio_offset;
 
 	getnanotime(&mfd->mfd_mtime);
 
 leave:
-	if (offp == &fp->f_offset)
-		mutex_exit(&fp->f_lock);
+	mutex_exit(&fp->f_lock);
 
 	return error;
 }
@@ -238,14 +239,21 @@ static int
 memfd_fcntl(file_t *fp, u_int cmd, void *data)
 {
 	struct memfd *mfd = fp->f_memfd;
+	int error = 0;
 
 	switch (cmd) {
 	case F_ADD_SEALS:
-		if (mfd->mfd_seals & F_SEAL_SEAL)
-			return EPERM;
+		mutex_enter(&fp->f_lock);
 
-		if (*(int *)data & ~MFD_KNOWN_SEALS)
-		        return EINVAL;
+		if (mfd->mfd_seals & F_SEAL_SEAL) {
+		        error = EPERM;
+			goto leave_add_seals;
+		}
+
+		if (*(int *)data & ~MFD_KNOWN_SEALS) {
+		        error = EINVAL;
+			goto leave_add_seals;
+		}
 
 		/*
 		 * Can only add F_SEAL_WRITE if there are no currently
@@ -257,13 +265,21 @@ memfd_fcntl(file_t *fp, u_int cmd, void 
 		if ((mfd->mfd_seals & F_SEAL_WRITE) == 0 &&
 		    (*(int *)data & F_SEAL_WRITE) != 0 &&
 		    mfd->mfd_uobj->uo_refs > 1)
-			return EBUSY;
+		{
+			error = EBUSY;
+			goto leave_add_seals;
+		}
 
 		mfd->mfd_seals |= *(int *)data;
-		return 0;
+
+	leave_add_seals:
+		mutex_exit(&fp->f_lock);
+		return error;
 
 	case F_GET_SEALS:
+		mutex_enter(&fp->f_lock);
 		*(int *)data = mfd->mfd_seals;
+		mutex_exit(&fp->f_lock);
 		return 0;
 
 	default:
@@ -276,6 +292,8 @@ memfd_stat(file_t *fp, struct stat *st)
 {
 	struct memfd *mfd = fp->f_memfd;
 
+	mutex_enter(&fp->f_lock);
+
 	memset(st, 0, sizeof(*st));
 	st->st_uid = kauth_cred_geteuid(fp->f_cred);
 	st->st_gid = kauth_cred_getegid(fp->f_cred);
@@ -290,6 +308,8 @@ memfd_stat(file_t *fp, struct stat *st)
 	st->st_atimespec = mfd->mfd_atime;
 	st->st_mtimespec = mfd->mfd_mtime;
 
+	mutex_exit(&fp->f_lock);
+
 	return 0;
 }
 
@@ -299,7 +319,6 @@ memfd_close(file_t *fp)
 	struct memfd *mfd = fp->f_memfd;
 
 	uao_detach(mfd->mfd_uobj);
-	mutex_destroy(&mfd->mfd_lock);
 
 	kmem_free(mfd, sizeof(*mfd));
 	fp->f_memfd = NULL;
@@ -312,20 +331,29 @@ memfd_mmap(file_t *fp, off_t *offp, size
     int *advicep, struct uvm_object **uobjp, int *maxprotp)
 {
 	struct memfd *mfd = fp->f_memfd;
+	int error = 0;
 
 	/* uvm_mmap guarantees page-aligned offset and size.  */
 	KASSERT(*offp == round_page(*offp));
 	KASSERT(size == round_page(size));
 	KASSERT(size > 0);
 
-	if (*offp < 0)
-		return EINVAL;
-	if (*offp + size > mfd->mfd_size)
-		return EINVAL;
+	mutex_enter(&fp->f_lock);
+
+	if (*offp < 0) {
+		error = EINVAL;
+		goto leave;
+	}
+	if (*offp + size > mfd->mfd_size) {
+		error = EINVAL;
+		goto leave;
+	}
 
 	if ((mfd->mfd_seals & F_SEAL_ANY_WRITE) &&
-	    (prot & VM_PROT_WRITE) && (*flagsp & MAP_PRIVATE) == 0)
-		return EPERM;
+	    (prot & VM_PROT_WRITE) && (*flagsp & MAP_PRIVATE) == 0) {
+		error = EPERM;
+		goto leave;
+	}
 
 	uao_reference(fp->f_memfd->mfd_uobj);
 	*uobjp = fp->f_memfd->mfd_uobj;
@@ -333,7 +361,10 @@ memfd_mmap(file_t *fp, off_t *offp, size
 	*maxprotp = prot;
 	*advicep = UVM_ADV_RANDOM;
 
-	return 0;
+leave:
+	mutex_exit(&fp->f_lock);
+
+	return error;
 }
 
 static int
@@ -341,7 +372,9 @@ memfd_seek(file_t *fp, off_t delta, int 
     int flags)
 {
 	off_t newoff;
-	int error;
+	int error = 0;
+
+	mutex_enter(&fp->f_lock);
 
 	switch (whence) {
 	case SEEK_CUR:
@@ -358,7 +391,7 @@ memfd_seek(file_t *fp, off_t delta, int 
 
 	default:
 		error = EINVAL;
-		return error;
+		goto leave;
 	}
 
 	if (newoffp)
@@ -366,15 +399,20 @@ memfd_seek(file_t *fp, off_t delta, int 
 	if (flags & FOF_UPDATE_OFFSET)
 		fp->f_offset = newoff;
 
-	return 0;
+leave:
+	mutex_exit(&fp->f_lock);
+
+	return error;
 }
 
 static int
-memfd_truncate(file_t *fp, off_t length)
+do_memfd_truncate(file_t *fp, off_t length)
 {
 	struct memfd *mfd = fp->f_memfd;
-	int error = 0;
 	voff_t start, end;
+	int error = 0;
+
+	KASSERT(mutex_owned(&fp->f_lock));
 
 	if (length < 0)
 		return EINVAL;
@@ -386,8 +424,6 @@ memfd_truncate(file_t *fp, off_t length)
 	if ((mfd->mfd_seals & F_SEAL_GROW) && length > mfd->mfd_size)
 		return EPERM;
 
-	mutex_enter(&mfd->mfd_lock);
-
 	if (length > mfd->mfd_size)
 		ubc_zerorange(mfd->mfd_uobj, mfd->mfd_size,
 		    length - mfd->mfd_size, 0);
@@ -406,6 +442,17 @@ memfd_truncate(file_t *fp, off_t length)
 
 	getnanotime(&mfd->mfd_mtime);
 	mfd->mfd_size = length;
-	mutex_exit(&mfd->mfd_lock);
+
+	return error;
+}
+
+static int
+memfd_truncate(file_t *fp, off_t length)
+{
+	int error;
+
+	mutex_enter(&fp->f_lock);
+	error = do_memfd_truncate(fp, length);
+	mutex_exit(&fp->f_lock);
 	return error;
 }

Index: src/sys/sys/memfd.h
diff -u src/sys/sys/memfd.h:1.3 src/sys/sys/memfd.h:1.4
--- src/sys/sys/memfd.h:1.3	Sat Jul 29 10:54:02 2023
+++ src/sys/sys/memfd.h	Sat Jul 29 13:54:54 2023
@@ -1,4 +1,4 @@
-/*	$NetBSD: memfd.h,v 1.3 2023/07/29 14:54:02 riastradh Exp $	*/
+/*	$NetBSD: memfd.h,v 1.4 2023/07/29 17:54:54 christos Exp $	*/
 
 /*-
  * Copyright (c) 2023 The NetBSD Foundation, Inc.
@@ -40,7 +40,6 @@ struct memfd {
 	struct uvm_object	*mfd_uobj;
 	size_t			mfd_size;
 	int			mfd_seals;
-	kmutex_t		mfd_lock;	/* for truncate */
 
 	struct timespec		mfd_btime;
 	struct timespec		mfd_atime;

Reply via email to