diff --git a/fs/exec.c b/fs/exec.c
index 537d92c41105..59cac7c18178 100644
--- a/fs/exec.c
+++ b/fs/exec.c
@@ -1307,6 +1307,8 @@ int begin_new_exec(struct linux_binprm * bprm)
 	 */
 	force_uaccess_begin();
 
+	if (me->flags & PF_KTHREAD)
+		free_kthread_struct(me);
 	me->flags &= ~(PF_RANDOMIZE | PF_FORKNOEXEC | PF_KTHREAD |
 					PF_NOFREEZE | PF_NO_SETAFFINITY);
 	flush_thread();
diff --git a/include/linux/kthread.h b/include/linux/kthread.h
index d86a7e3b9a52..4f3433afb54b 100644
--- a/include/linux/kthread.h
+++ b/include/linux/kthread.h
@@ -33,7 +33,7 @@ struct task_struct *kthread_create_on_cpu(int (*threadfn)(void *data),
 					  unsigned int cpu,
 					  const char *namefmt);
 
-void set_kthread_struct(struct task_struct *p);
+bool set_kthread_struct(struct task_struct *p);
 
 void kthread_set_per_cpu(struct task_struct *k, int cpu);
 bool kthread_is_per_cpu(struct task_struct *k);
diff --git a/kernel/fork.c b/kernel/fork.c
index 3244cc56b697..04fa3e5d97af 100644
--- a/kernel/fork.c
+++ b/kernel/fork.c
@@ -2118,6 +2118,10 @@ static __latent_entropy struct task_struct *copy_process(
 	p->io_context = NULL;
 	audit_set_context(p, NULL);
 	cgroup_fork(p);
+	if (p->flags & PF_KTHREAD) {
+		if (!set_kthread_struct(p))
+			goto bad_fork_cleanup_threadgroup_lock;
+	}
 #ifdef CONFIG_NUMA
 	p->mempolicy = mpol_dup(p->mempolicy);
 	if (IS_ERR(p->mempolicy)) {
diff --git a/kernel/kthread.c b/kernel/kthread.c
index 4388d6694a7f..8e5f44bed027 100644
--- a/kernel/kthread.c
+++ b/kernel/kthread.c
@@ -93,20 +93,27 @@ static inline struct kthread *__to_kthread(struct task_struct *p)
 	return kthread;
 }
 
-void set_kthread_struct(struct task_struct *p)
+bool set_kthread_struct(struct task_struct *p)
 {
 	struct kthread *kthread;
 
-	if (__to_kthread(p))
-		return;
+	if (WARN_ON_ONCE(to_kthread(p)))
+		return false;
 
 	kthread = kzalloc(sizeof(*kthread), GFP_KERNEL);
+	if (!kthread)
+		return false;
+
+	init_completion(&kthread->exited);
+	init_completion(&kthread->parked);
+	p->vfork_done = &kthread->exited;
+
 	/*
 	 * We abuse ->set_child_tid to avoid the new member and because it
-	 * can't be wrongly copied by copy_process(). We also rely on fact
-	 * that the caller can't exec, so PF_KTHREAD can't be cleared.
+	 * can't be wrongly copied by copy_process().
 	 */
 	p->set_child_tid = (__force void __user *)kthread;
+	return true;
 }
 
 void free_kthread_struct(struct task_struct *k)
@@ -114,13 +121,13 @@ void free_kthread_struct(struct task_struct *k)
 	struct kthread *kthread;
 
 	/*
-	 * Can be NULL if this kthread was created by kernel_thread()
-	 * or if kmalloc() in kthread() failed.
+	 * Can be NULL if kmalloc() in set_kthread_struct() failed.
 	 */
 	kthread = to_kthread(k);
 #ifdef CONFIG_BLK_CGROUP
 	WARN_ON_ONCE(kthread && kthread->blkcg_css);
 #endif
+	k->set_child_tid = (__force void __user *)NULL;
 	kfree(kthread);
 }
 
@@ -315,7 +322,6 @@ static int kthread(void *_create)
 	struct kthread *self;
 	int ret;
 
-	set_kthread_struct(current);
 	self = to_kthread(current);
 
 	/* If user was SIGKILLed, I release the structure. */
@@ -325,17 +331,8 @@ static int kthread(void *_create)
 		kthread_exit(-EINTR);
 	}
 
-	if (!self) {
-		create->result = ERR_PTR(-ENOMEM);
-		complete(done);
-		kthread_exit(-ENOMEM);
-	}
-
 	self->threadfn = threadfn;
 	self->data = data;
-	init_completion(&self->exited);
-	init_completion(&self->parked);
-	current->vfork_done = &self->exited;
 
 	/*
 	 * The new thread inherited kthreadd's priority and CPU mask. Reset
diff --git a/kernel/sched/core.c b/kernel/sched/core.c
index 3c9b0fda64ac..0404a8c572a1 100644
--- a/kernel/sched/core.c
+++ b/kernel/sched/core.c
@@ -8599,14 +8599,6 @@ void __init init_idle(struct task_struct *idle, int cpu)
 
 	__sched_fork(0, idle);
 
-	/*
-	 * The idle task doesn't need the kthread struct to function, but it
-	 * is dressed up as a per-CPU kthread and thus needs to play the part
-	 * if we want to avoid special-casing it in code that deals with per-CPU
-	 * kthreads.
-	 */
-	set_kthread_struct(idle);
-
 	raw_spin_lock_irqsave(&idle->pi_lock, flags);
 	raw_spin_rq_lock(rq);
 
@@ -9427,6 +9419,14 @@ void __init sched_init(void)
 	mmgrab(&init_mm);
 	enter_lazy_tlb(&init_mm, current);
 
+	/*
+	 * The idle task doesn't need the kthread struct to function, but it
+	 * is dressed up as a per-CPU kthread and thus needs to play the part
+	 * if we want to avoid special-casing it in code that deals with per-CPU
+	 * kthreads.
+	 */
+	WARN_ON(!set_kthread_struct(current));
+
 	/*
 	 * Make us the idle thread. Technically, schedule() should not be
 	 * called from this thread, however somewhere below it might be,