This patch avoids allocation of the kthread structure on the stack and simply
uses kmalloc.  Allocation on the stack became a huge problem (with memory
corruption and all the other unpleasant consequences) after commit 2deb4be28
by Andy Lutomirski, which rewinds the stack on oops, so an oopsed kthread
steps on garbage memory while completing the task->vfork_done structure
on the following path:

       oops_end()
       rewind_stack_do_exit()
       exit_mm()
       mm_release()
       complete_vfork_done()

Also in this patch the two structures 'struct kthread_create_info' and
'struct kthread' are merged into a single 'struct kthread', whose freeing
is controlled by a reference counter.

The last reference on the kthread is put from a task work callback,
which is invoked from do_exit().  The major point is that the last
put happens *after* complete_vfork_done() is invoked.

Signed-off-by: Roman Pen <roman.peny...@profitbricks.com>
Cc: Andy Lutomirski <l...@kernel.org>
Cc: Oleg Nesterov <o...@redhat.com>
Cc: Peter Zijlstra <pet...@infradead.org>
Cc: Thomas Gleixner <t...@linutronix.de>
Cc: Ingo Molnar <mi...@redhat.com>
Cc: Tejun Heo <t...@kernel.org>
Cc: linux-kernel@vger.kernel.org
---
v2:
  o let x86/kernel/dumpstack.c rewind a stack, but do not use a stack
    for a structure allocation.

 kernel/kthread.c | 160 +++++++++++++++++++++++++++++++------------------------
 1 file changed, 90 insertions(+), 70 deletions(-)

diff --git a/kernel/kthread.c b/kernel/kthread.c
index 4ab4c37..9ccfe06 100644
--- a/kernel/kthread.c
+++ b/kernel/kthread.c
@@ -18,14 +18,19 @@
 #include <linux/freezer.h>
 #include <linux/ptrace.h>
 #include <linux/uaccess.h>
+#include <linux/task_work.h>
 #include <trace/events/sched.h>
 
 static DEFINE_SPINLOCK(kthread_create_lock);
 static LIST_HEAD(kthread_create_list);
 struct task_struct *kthreadd_task;
 
-struct kthread_create_info
-{
+struct kthread {
+       struct list_head list;
+       unsigned long flags;
+       unsigned int cpu;
+       atomic_t refs;
+
        /* Information passed to kthread() from kthreadd. */
        int (*threadfn)(void *data);
        void *data;
@@ -33,15 +38,9 @@ struct kthread_create_info
 
        /* Result passed back to kthread_create() from kthreadd. */
        struct task_struct *result;
-       struct completion *done;
-
-       struct list_head list;
-};
 
-struct kthread {
-       unsigned long flags;
-       unsigned int cpu;
-       void *data;
+       struct callback_head put_work;
+       struct completion *started;
        struct completion parked;
        struct completion exited;
 };
@@ -69,6 +68,24 @@ static struct kthread *to_live_kthread(struct task_struct *k)
        return NULL;
 }
 
+static inline void put_kthread(struct kthread *kthread)
+{
+       if (atomic_dec_and_test(&kthread->refs))
+               kfree(kthread);
+}
+
+/**
+ * put_kthread_cb - called from do_exit(); likely performs
+ *                  the final put.
+ */
+static void put_kthread_cb(struct callback_head *work)
+{
+       struct kthread *kthread;
+
+       kthread = container_of(work, struct kthread, put_work);
+       put_kthread(kthread);
+}
+
 /**
  * kthread_should_stop - should this kthread return now?
  *
@@ -174,41 +191,36 @@ void kthread_parkme(void)
 }
 EXPORT_SYMBOL_GPL(kthread_parkme);
 
-static int kthread(void *_create)
+static int kthreadfn(void *_self)
 {
-       /* Copy data: it's on kthread's stack */
-       struct kthread_create_info *create = _create;
-       int (*threadfn)(void *data) = create->threadfn;
-       void *data = create->data;
-       struct completion *done;
-       struct kthread self;
-       int ret;
-
-       self.flags = 0;
-       self.data = data;
-       init_completion(&self.exited);
-       init_completion(&self.parked);
-       current->vfork_done = &self.exited;
-
-       /* If user was SIGKILLed, I release the structure. */
-       done = xchg(&create->done, NULL);
-       if (!done) {
-               kfree(create);
-               do_exit(-EINTR);
+       struct completion *started;
+       struct kthread *self = _self;
+       int ret = -EINTR;
+
+       /* If user was SIGKILLed, put a ref and exit silently. */
+       started = xchg(&self->started, NULL);
+       if (!started) {
+               put_kthread(self);
+               goto exit;
        }
+       /* Delegate last ref put to a task work, which will happen
+        * after 'vfork_done' completion.
+        */
+       init_task_work(&self->put_work, put_kthread_cb);
+       task_work_add(current, &self->put_work, false);
+       current->vfork_done = &self->exited;
+
        /* OK, tell user we're spawned, wait for stop or wakeup */
        __set_current_state(TASK_UNINTERRUPTIBLE);
-       create->result = current;
-       complete(done);
+       self->result = current;
+       complete(started);
        schedule();
 
-       ret = -EINTR;
-
-       if (!test_bit(KTHREAD_SHOULD_STOP, &self.flags)) {
-               __kthread_parkme(&self);
-               ret = threadfn(data);
+       if (!test_bit(KTHREAD_SHOULD_STOP, &self->flags)) {
+               __kthread_parkme(self);
+               ret = self->threadfn(self->data);
        }
-       /* we can't just return, we must preserve "self" on stack */
+exit:
        do_exit(ret);
 }
 
@@ -222,25 +234,24 @@ int tsk_fork_get_node(struct task_struct *tsk)
        return NUMA_NO_NODE;
 }
 
-static void create_kthread(struct kthread_create_info *create)
+static void create_kthread(struct kthread *kthread)
 {
+       struct completion *started;
        int pid;
 
 #ifdef CONFIG_NUMA
-       current->pref_node_fork = create->node;
+       current->pref_node_fork = kthread->node;
 #endif
        /* We want our own signal handler (we take no signals by default). */
-       pid = kernel_thread(kthread, create, CLONE_FS | CLONE_FILES | SIGCHLD);
+       pid = kernel_thread(kthreadfn, kthread,
+                           CLONE_FS | CLONE_FILES | SIGCHLD);
        if (pid < 0) {
-               /* If user was SIGKILLed, I release the structure. */
-               struct completion *done = xchg(&create->done, NULL);
-
-               if (!done) {
-                       kfree(create);
-                       return;
+               started = xchg(&kthread->started, NULL);
+               if (started) {
+                       kthread->result = ERR_PTR(pid);
+                       complete(started);
                }
-               create->result = ERR_PTR(pid);
-               complete(done);
+               put_kthread(kthread);
        }
 }
 
@@ -272,20 +283,26 @@ struct task_struct *kthread_create_on_node(int 
(*threadfn)(void *data),
                                           const char namefmt[],
                                           ...)
 {
-       DECLARE_COMPLETION_ONSTACK(done);
+       DECLARE_COMPLETION_ONSTACK(started);
        struct task_struct *task;
-       struct kthread_create_info *create = kmalloc(sizeof(*create),
-                                                    GFP_KERNEL);
+       struct kthread *kthread;
 
-       if (!create)
+       kthread = kmalloc(sizeof(*kthread), GFP_KERNEL);
+       if (!kthread)
                return ERR_PTR(-ENOMEM);
-       create->threadfn = threadfn;
-       create->data = data;
-       create->node = node;
-       create->done = &done;
+       /* One ref for us and one ref for a new kernel thread. */
+       atomic_set(&kthread->refs, 2);
+       kthread->flags = 0;
+       kthread->cpu = 0;
+       kthread->threadfn = threadfn;
+       kthread->data = data;
+       kthread->node = node;
+       kthread->started = &started;
+       init_completion(&kthread->exited);
+       init_completion(&kthread->parked);
 
        spin_lock(&kthread_create_lock);
-       list_add_tail(&create->list, &kthread_create_list);
+       list_add_tail(&kthread->list, &kthread_create_list);
        spin_unlock(&kthread_create_lock);
 
        wake_up_process(kthreadd_task);
@@ -294,21 +311,23 @@ struct task_struct *kthread_create_on_node(int 
(*threadfn)(void *data),
         * the OOM killer while kthreadd is trying to allocate memory for
         * new kernel thread.
         */
-       if (unlikely(wait_for_completion_killable(&done))) {
+       if (unlikely(wait_for_completion_killable(&started))) {
                /*
                 * If I was SIGKILLed before kthreadd (or new kernel thread)
-                * calls complete(), leave the cleanup of this structure to
-                * that thread.
+                * calls complete(), put a ref and return an error.
                 */
-               if (xchg(&create->done, NULL))
+               if (xchg(&kthread->started, NULL)) {
+                       put_kthread(kthread);
+
                        return ERR_PTR(-EINTR);
+               }
                /*
                 * kthreadd (or new kernel thread) will call complete()
                 * shortly.
                 */
-               wait_for_completion(&done);
+               wait_for_completion(&started);
        }
-       task = create->result;
+       task = kthread->result;
        if (!IS_ERR(task)) {
                static const struct sched_param param = { .sched_priority = 0 };
                va_list args;
@@ -323,7 +342,8 @@ struct task_struct *kthread_create_on_node(int 
(*threadfn)(void *data),
                sched_setscheduler_nocheck(task, SCHED_NORMAL, &param);
                set_cpus_allowed_ptr(task, cpu_all_mask);
        }
-       kfree(create);
+       put_kthread(kthread);
+
        return task;
 }
 EXPORT_SYMBOL(kthread_create_on_node);
@@ -523,14 +543,14 @@ int kthreadd(void *unused)
 
                spin_lock(&kthread_create_lock);
                while (!list_empty(&kthread_create_list)) {
-                       struct kthread_create_info *create;
+                       struct kthread *kthread;
 
-                       create = list_entry(kthread_create_list.next,
-                                           struct kthread_create_info, list);
-                       list_del_init(&create->list);
+                       kthread = list_entry(kthread_create_list.next,
+                                            struct kthread, list);
+                       list_del_init(&kthread->list);
                        spin_unlock(&kthread_create_lock);
 
-                       create_kthread(create);
+                       create_kthread(kthread);
 
                        spin_lock(&kthread_create_lock);
                }
-- 
2.9.3

Reply via email to