diff --git a/mm/ksm.c b/mm/ksm.c
index e87dec7ba950..9f182f99ff5e 100644
--- a/mm/ksm.c
+++ b/mm/ksm.c
@@ -1177,8 +1177,18 @@ again:
 		cond_resched();
 		stable_node = rb_entry(*new, struct stable_node, node);
 		tree_page = get_ksm_page(stable_node, false);
-		if (!tree_page)
-			return NULL;
+		if (!tree_page) {
+			/*
+			 * If we walked over a stale stable_node,
+			 * get_ksm_page() will call rb_erase() and it
+			 * may rebalance the tree from under us. So
+			 * restart the search from scratch. Returning
+			 * NULL would be safe too, but we'd generate
+			 * false negative insertions just because some
+			 * stable_node was stale.
+			 */
+			goto again;
+		}
 
 		ret = memcmp_pages(page, tree_page);
 		put_page(tree_page);
@@ -1254,12 +1264,14 @@ static struct stable_node *stable_tree_insert(struct page *kpage)
 	unsigned long kpfn;
 	struct rb_root *root;
 	struct rb_node **new;
-	struct rb_node *parent = NULL;
+	struct rb_node *parent;
 	struct stable_node *stable_node;
 
 	kpfn = page_to_pfn(kpage);
 	nid = get_kpfn_nid(kpfn);
 	root = root_stable_tree + nid;
+again:
+	parent = NULL;
 	new = &root->rb_node;
 
 	while (*new) {
@@ -1269,8 +1281,18 @@ static struct stable_node *stable_tree_insert(struct page *kpage)
 		cond_resched();
 		stable_node = rb_entry(*new, struct stable_node, node);
 		tree_page = get_ksm_page(stable_node, false);
-		if (!tree_page)
-			return NULL;
+		if (!tree_page) {
+			/*
+			 * If we walked over a stale stable_node,
+			 * get_ksm_page() will call rb_erase() and it
+			 * may rebalance the tree from under us. So
+			 * restart the search from scratch. Returning
+			 * NULL would be safe too, but we'd generate
+			 * false negative insertions just because some
+			 * stable_node was stale.
+			 */
+			goto again;
+		}
 
 		ret = memcmp_pages(kpage, tree_page);
 		put_page(tree_page);