NT waits can optionally be made "alertable". This is a special channel for
thread wakeup that is mildly similar to SIGIO. A thread has a single internal
bit of "alerted" state. If a thread is alerted while in an alertable wait, the
wait will return a special value, consume the "alerted" state, and will not
consume any of its objects.

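Alerts are implemented using events. The user-space NT emulator is expected to
create an internal ntsync event for each thread and pass that event to the wait
functions through the new "alert" field of struct ntsync_wait_args. The alert
event is always checked after the regular wait objects, so a signaled object
takes precedence over a pending alert.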

Signed-off-by: Elizabeth Figura <zfig...@codeweavers.com>
---
 drivers/misc/ntsync.c       | 70 ++++++++++++++++++++++++++++++++-----
 include/uapi/linux/ntsync.h |  3 +-
 2 files changed, 63 insertions(+), 10 deletions(-)

diff --git a/drivers/misc/ntsync.c b/drivers/misc/ntsync.c
index 78dc405bb759..457ff28b789f 100644
--- a/drivers/misc/ntsync.c
+++ b/drivers/misc/ntsync.c
@@ -869,22 +869,29 @@ static int setup_wait(struct ntsync_device *dev,
                      const struct ntsync_wait_args *args, bool all,
                      struct ntsync_q **ret_q)
 {
+       int fds[NTSYNC_MAX_WAIT_COUNT + 1];
        const __u32 count = args->count;
-       int fds[NTSYNC_MAX_WAIT_COUNT];
        struct ntsync_q *q;
+       __u32 total_count;
        __u32 i, j;
 
-       if (args->pad[0] || args->pad[1] || (args->flags & ~NTSYNC_WAIT_REALTIME))
+       if (args->pad || (args->flags & ~NTSYNC_WAIT_REALTIME))
                return -EINVAL;
 
        if (args->count > NTSYNC_MAX_WAIT_COUNT)
                return -EINVAL;
 
+       total_count = count;
+       if (args->alert)
+               total_count++;
+
        if (copy_from_user(fds, u64_to_user_ptr(args->objs),
                           array_size(count, sizeof(*fds))))
                return -EFAULT;
+       if (args->alert)
+               fds[count] = args->alert;
 
-       q = kmalloc(struct_size(q, entries, count), GFP_KERNEL);
+       q = kmalloc(struct_size(q, entries, total_count), GFP_KERNEL);
        if (!q)
                return -ENOMEM;
        q->task = current;
@@ -894,7 +901,7 @@ static int setup_wait(struct ntsync_device *dev,
        q->ownerdead = false;
        q->count = count;
 
-       for (i = 0; i < count; i++) {
+       for (i = 0; i < total_count; i++) {
                struct ntsync_q_entry *entry = &q->entries[i];
                struct ntsync_obj *obj = get_obj(dev, fds[i]);
 
@@ -944,10 +951,10 @@ static void try_wake_any_obj(struct ntsync_obj *obj)
 static int ntsync_wait_any(struct ntsync_device *dev, void __user *argp)
 {
        struct ntsync_wait_args args;
+       __u32 i, total_count;
        struct ntsync_q *q;
        int signaled;
        bool all;
-       __u32 i;
        int ret;
 
        if (copy_from_user(&args, argp, sizeof(args)))
@@ -957,9 +964,13 @@ static int ntsync_wait_any(struct ntsync_device *dev, void __user *argp)
        if (ret < 0)
                return ret;
 
+       total_count = args.count;
+       if (args.alert)
+               total_count++;
+
        /* queue ourselves */
 
-       for (i = 0; i < args.count; i++) {
+       for (i = 0; i < total_count; i++) {
                struct ntsync_q_entry *entry = &q->entries[i];
                struct ntsync_obj *obj = entry->obj;
 
@@ -968,9 +979,15 @@ static int ntsync_wait_any(struct ntsync_device *dev, void __user *argp)
                ntsync_unlock_obj(dev, obj, all);
        }
 
-       /* check if we are already signaled */
+       /*
+        * Check if we are already signaled.
+        *
+        * Note that the API requires that normal objects are checked before
+        * the alert event. Hence we queue the alert event last, and check
+        * objects in order.
+        */
 
-       for (i = 0; i < args.count; i++) {
+       for (i = 0; i < total_count; i++) {
                struct ntsync_obj *obj = q->entries[i].obj;
 
                if (atomic_read(&q->signaled) != -1)
@@ -987,7 +1004,7 @@ static int ntsync_wait_any(struct ntsync_device *dev, void __user *argp)
 
        /* and finally, unqueue */
 
-       for (i = 0; i < args.count; i++) {
+       for (i = 0; i < total_count; i++) {
                struct ntsync_q_entry *entry = &q->entries[i];
                struct ntsync_obj *obj = entry->obj;
 
@@ -1047,6 +1064,14 @@ static int ntsync_wait_all(struct ntsync_device *dev, void __user *argp)
                 */
                list_add_tail(&entry->node, &obj->all_waiters);
        }
+       if (args.alert) {
+               struct ntsync_q_entry *entry = &q->entries[args.count];
+               struct ntsync_obj *obj = entry->obj;
+
+               dev_lock_obj(dev, obj);
+               list_add_tail(&entry->node, &obj->any_waiters);
+               dev_unlock_obj(dev, obj);
+       }
 
        /* check if we are already signaled */
 
@@ -1054,6 +1079,21 @@ static int ntsync_wait_all(struct ntsync_device *dev, void __user *argp)
 
        mutex_unlock(&dev->wait_all_lock);
 
+       /*
+        * Check if the alert event is signaled, making sure to do so only
+        * after checking if the other objects are signaled.
+        */
+
+       if (args.alert) {
+               struct ntsync_obj *obj = q->entries[args.count].obj;
+
+               if (atomic_read(&q->signaled) == -1) {
+                       bool all = ntsync_lock_obj(dev, obj);
+                       try_wake_any_obj(obj);
+                       ntsync_unlock_obj(dev, obj, all);
+               }
+       }
+
        /* sleep */
 
        ret = ntsync_schedule(q, &args);
@@ -1079,6 +1119,18 @@ static int ntsync_wait_all(struct ntsync_device *dev, void __user *argp)
 
        mutex_unlock(&dev->wait_all_lock);
 
+       if (args.alert) {
+               struct ntsync_q_entry *entry = &q->entries[args.count];
+               struct ntsync_obj *obj = entry->obj;
+               bool all;
+
+               all = ntsync_lock_obj(dev, obj);
+               list_del(&entry->node);
+               ntsync_unlock_obj(dev, obj, all);
+
+               put_obj(obj);
+       }
+
        signaled = atomic_read(&q->signaled);
        if (signaled != -1) {
                struct ntsync_wait_args __user *user_args = argp;
diff --git a/include/uapi/linux/ntsync.h b/include/uapi/linux/ntsync.h
index b9d208a8c00f..6d06793512b1 100644
--- a/include/uapi/linux/ntsync.h
+++ b/include/uapi/linux/ntsync.h
@@ -34,7 +34,8 @@ struct ntsync_wait_args {
        __u32 index;
        __u32 flags;
        __u32 owner;
-       __u32 pad[2];
+       __u32 alert;
+       __u32 pad;
 };
 
 #define NTSYNC_MAX_WAIT_COUNT 64
-- 
2.45.2


Reply via email to