diff --git a/include/linux/oom.h b/include/linux/oom.h
index 922dab164eb0..da60007075b5 100644
--- a/include/linux/oom.h
+++ b/include/linux/oom.h
@@ -29,8 +29,23 @@ enum oom_scan_t {
 	OOM_SCAN_SELECT,	/* always select this thread first */
 };
 
-extern void compare_swap_oom_score_adj(short old_val, short new_val);
-extern short test_set_oom_score_adj(short new_val);
+/* Thread is the potential origin of an oom condition; kill first on oom */
+#define OOM_FLAG_ORIGIN		((__force oom_flags_t)0x1)
+
+static inline void set_current_oom_origin(void)
+{
+	current->signal->oom_flags |= OOM_FLAG_ORIGIN;
+}
+
+static inline void clear_current_oom_origin(void)
+{
+	current->signal->oom_flags &= ~OOM_FLAG_ORIGIN;
+}
+
+static inline bool oom_task_origin(const struct task_struct *p)
+{
+	return !!(p->signal->oom_flags & OOM_FLAG_ORIGIN);
+}
 
 extern unsigned long oom_badness(struct task_struct *p,
 		struct mem_cgroup *memcg, const nodemask_t *nodemask,
diff --git a/include/linux/sched.h b/include/linux/sched.h
index ed30456152da..3e387df065fc 100644
--- a/include/linux/sched.h
+++ b/include/linux/sched.h
@@ -631,6 +631,7 @@ struct signal_struct {
 	struct rw_semaphore group_rwsem;
 #endif
 
+	oom_flags_t oom_flags;
 	short oom_score_adj;		/* OOM kill score adjustment */
 	short oom_score_adj_min;	/* OOM kill score adjustment min value.
 					 * Only settable by CAP_SYS_RESOURCE. */
diff --git a/include/linux/types.h b/include/linux/types.h
index 1cc0e4b9a048..4d118ba11349 100644
--- a/include/linux/types.h
+++ b/include/linux/types.h
@@ -156,6 +156,7 @@ typedef u32 dma_addr_t;
 #endif
 typedef unsigned __bitwise__ gfp_t;
 typedef unsigned __bitwise__ fmode_t;
+typedef unsigned __bitwise__ oom_flags_t;
 
 #ifdef CONFIG_PHYS_ADDR_T_64BIT
 typedef u64 phys_addr_t;
diff --git a/mm/ksm.c b/mm/ksm.c
index b4d5a9deb17f..382d930a0bf1 100644
--- a/mm/ksm.c
+++ b/mm/ksm.c
@@ -1919,12 +1919,9 @@ static ssize_t run_store(struct kobject *kobj, struct kobj_attribute *attr,
 	if (ksm_run != flags) {
 		ksm_run = flags;
 		if (flags & KSM_RUN_UNMERGE) {
-			short oom_score_adj;
-
-			oom_score_adj = test_set_oom_score_adj(OOM_SCORE_ADJ_MAX);
+			set_current_oom_origin();
 			err = unmerge_and_remove_all_rmap_items();
-			compare_swap_oom_score_adj(OOM_SCORE_ADJ_MAX,
-								oom_score_adj);
+			clear_current_oom_origin();
 			if (err) {
 				ksm_run = KSM_RUN_STOP;
 				count = err;
diff --git a/mm/oom_kill.c b/mm/oom_kill.c
index 37ab4c5ab6e8..18f1ae2b45de 100644
--- a/mm/oom_kill.c
+++ b/mm/oom_kill.c
@@ -44,48 +44,6 @@ int sysctl_oom_kill_allocating_task;
 int sysctl_oom_dump_tasks = 1;
 static DEFINE_SPINLOCK(zone_scan_lock);
 
-/*
- * compare_swap_oom_score_adj() - compare and swap current's oom_score_adj
- * @old_val: old oom_score_adj for compare
- * @new_val: new oom_score_adj for swap
- *
- * Sets the oom_score_adj value for current to @new_val iff its present value is
- * @old_val.  Usually used to reinstate a previous value to prevent racing with
- * userspacing tuning the value in the interim.
- */
-void compare_swap_oom_score_adj(short old_val, short new_val)
-{
-	struct sighand_struct *sighand = current->sighand;
-
-	spin_lock_irq(&sighand->siglock);
-	if (current->signal->oom_score_adj == old_val)
-		current->signal->oom_score_adj = new_val;
-	trace_oom_score_adj_update(current);
-	spin_unlock_irq(&sighand->siglock);
-}
-
-/**
- * test_set_oom_score_adj() - set current's oom_score_adj and return old value
- * @new_val: new oom_score_adj value
- *
- * Sets the oom_score_adj value for current to @new_val with proper
- * synchronization and returns the old value.  Usually used to temporarily
- * set a value, save the old value in the caller, and then reinstate it later.
- */
-short test_set_oom_score_adj(short new_val)
-{
-	struct sighand_struct *sighand = current->sighand;
-	int old_val;
-
-	spin_lock_irq(&sighand->siglock);
-	old_val = current->signal->oom_score_adj;
-	current->signal->oom_score_adj = new_val;
-	trace_oom_score_adj_update(current);
-	spin_unlock_irq(&sighand->siglock);
-
-	return old_val;
-}
-
 #ifdef CONFIG_NUMA
 /**
  * has_intersects_mems_allowed() - check task eligiblity for kill
@@ -310,6 +268,13 @@ enum oom_scan_t oom_scan_process_thread(struct task_struct *task,
 	if (!task->mm)
 		return OOM_SCAN_CONTINUE;
 
+	/*
+	 * If task is allocating a lot of memory and has been marked to be
+	 * killed first if it triggers an oom, then select it.
+	 */
+	if (oom_task_origin(task))
+		return OOM_SCAN_SELECT;
+
 	if (task->flags & PF_EXITING && !force_kill) {
 		/*
 		 * If this task is not being ptraced on exit, then wait for it
diff --git a/mm/swapfile.c b/mm/swapfile.c
index bb6f6a04e92d..e97a0e5aea91 100644
--- a/mm/swapfile.c
+++ b/mm/swapfile.c
@@ -1498,7 +1498,6 @@ SYSCALL_DEFINE1(swapoff, const char __user *, specialfile)
 	struct address_space *mapping;
 	struct inode *inode;
 	struct filename *pathname;
-	short oom_score_adj;
 	int i, type, prev;
 	int err;
 
@@ -1557,9 +1556,9 @@ SYSCALL_DEFINE1(swapoff, const char __user *, specialfile)
 	p->flags &= ~SWP_WRITEOK;
 	spin_unlock(&swap_lock);
 
-	oom_score_adj = test_set_oom_score_adj(OOM_SCORE_ADJ_MAX);
+	set_current_oom_origin();
 	err = try_to_unuse(type, false, 0); /* force all pages to be unused */
-	compare_swap_oom_score_adj(OOM_SCORE_ADJ_MAX, oom_score_adj);
+	clear_current_oom_origin();
 
 	if (err) {
 		/* re-insert swap space back into swap_list */