From 3fd2f1b9d91284afb957ab9899a83279d0e09f29 Mon Sep 17 00:00:00 2001
From: Eric Dumazet <edumazet@google.com>
Date: Wed, 8 Jul 2015 14:28:28 -0700
Subject: [PATCH 1/3] inet: remove BUG_ON() in twsk_destructor()

Kernel will crash the same if one of the pointer is NULL anyway.

Signed-off-by: Eric Dumazet <edumazet@google.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
---
 include/net/timewait_sock.h | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/include/net/timewait_sock.h b/include/net/timewait_sock.h
index 68f0ecad6c6e..1a47946f95ba 100644
--- a/include/net/timewait_sock.h
+++ b/include/net/timewait_sock.h
@@ -33,9 +33,6 @@ static inline int twsk_unique(struct sock *sk, struct sock *sktw, void *twp)
 
 static inline void twsk_destructor(struct sock *sk)
 {
-	BUG_ON(sk == NULL);
-	BUG_ON(sk->sk_prot == NULL);
-	BUG_ON(sk->sk_prot->twsk_prot == NULL);
 	if (sk->sk_prot->twsk_prot->twsk_destructor != NULL)
 		sk->sk_prot->twsk_prot->twsk_destructor(sk);
 }

From fc01538f9fb75572c969ca9988176ffc2a8741d6 Mon Sep 17 00:00:00 2001
From: Eric Dumazet <edumazet@google.com>
Date: Wed, 8 Jul 2015 14:28:29 -0700
Subject: [PATCH 2/3] inet: simplify timewait refcounting

timewait sockets have a complex refcounting logic.
Once we realize it should be similar to established and
syn_recv sockets, we can use sk_nulls_del_node_init_rcu()
and remove inet_twsk_unhash()

In particular, deferred inet_twsk_put() added in commit
13475a30b66cd ("tcp: connect() race with timewait reuse")
looks unecessary : When removing a timewait socket from
ehash or bhash, caller must own a reference on the socket
anyway.

Signed-off-by: Eric Dumazet <edumazet@google.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
---
 include/net/inet_hashtables.h    |  4 +--
 include/net/inet_timewait_sock.h |  6 ++---
 net/ipv4/inet_hashtables.c       | 31 ++++++++---------------
 net/ipv4/inet_timewait_sock.c    | 42 +++++---------------------------
 net/ipv6/inet6_hashtables.c      |  6 +----
 5 files changed, 21 insertions(+), 68 deletions(-)

diff --git a/include/net/inet_hashtables.h b/include/net/inet_hashtables.h
index b73c88a19dd4..b07d126694a7 100644
--- a/include/net/inet_hashtables.h
+++ b/include/net/inet_hashtables.h
@@ -205,8 +205,8 @@ void inet_put_port(struct sock *sk);
 
 void inet_hashinfo_init(struct inet_hashinfo *h);
 
-int __inet_hash_nolisten(struct sock *sk, struct inet_timewait_sock *tw);
-int __inet_hash(struct sock *sk, struct inet_timewait_sock *tw);
+void __inet_hash_nolisten(struct sock *sk, struct sock *osk);
+void __inet_hash(struct sock *sk, struct sock *osk);
 void inet_hash(struct sock *sk);
 void inet_unhash(struct sock *sk);
 
diff --git a/include/net/inet_timewait_sock.h b/include/net/inet_timewait_sock.h
index 360c4802288d..96f52a4711c8 100644
--- a/include/net/inet_timewait_sock.h
+++ b/include/net/inet_timewait_sock.h
@@ -100,10 +100,8 @@ static inline struct inet_timewait_sock *inet_twsk(const struct sock *sk)
 void inet_twsk_free(struct inet_timewait_sock *tw);
 void inet_twsk_put(struct inet_timewait_sock *tw);
 
-int inet_twsk_unhash(struct inet_timewait_sock *tw);
-
-int inet_twsk_bind_unhash(struct inet_timewait_sock *tw,
-			  struct inet_hashinfo *hashinfo);
+void inet_twsk_bind_unhash(struct inet_timewait_sock *tw,
+			   struct inet_hashinfo *hashinfo);
 
 struct inet_timewait_sock *inet_twsk_alloc(const struct sock *sk,
 					   struct inet_timewait_death_row *dr,
diff --git a/net/ipv4/inet_hashtables.c b/net/ipv4/inet_hashtables.c
index 5f9b063bbe8a..e58840330da7 100644
--- a/net/ipv4/inet_hashtables.c
+++ b/net/ipv4/inet_hashtables.c
@@ -343,7 +343,6 @@ static int __inet_check_established(struct inet_timewait_death_row *death_row,
 	struct sock *sk2;
 	const struct hlist_nulls_node *node;
 	struct inet_timewait_sock *tw = NULL;
-	int twrefcnt = 0;
 
 	spin_lock(lock);
 
@@ -371,12 +370,10 @@ static int __inet_check_established(struct inet_timewait_death_row *death_row,
 	WARN_ON(!sk_unhashed(sk));
 	__sk_nulls_add_node_rcu(sk, &head->chain);
 	if (tw) {
-		twrefcnt = inet_twsk_unhash(tw);
+		sk_nulls_del_node_init_rcu((struct sock *)tw);
 		NET_INC_STATS_BH(net, LINUX_MIB_TIMEWAITRECYCLED);
 	}
 	spin_unlock(lock);
-	if (twrefcnt)
-		inet_twsk_put(tw);
 	sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);
 
 	if (twp) {
@@ -384,7 +381,6 @@ static int __inet_check_established(struct inet_timewait_death_row *death_row,
 	} else if (tw) {
 		/* Silly. Should hash-dance instead... */
 		inet_twsk_deschedule(tw);
-
 		inet_twsk_put(tw);
 	}
 	return 0;
@@ -403,13 +399,12 @@ static u32 inet_sk_port_offset(const struct sock *sk)
 					  inet->inet_dport);
 }
 
-int __inet_hash_nolisten(struct sock *sk, struct inet_timewait_sock *tw)
+void __inet_hash_nolisten(struct sock *sk, struct sock *osk)
 {
 	struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo;
 	struct hlist_nulls_head *list;
 	struct inet_ehash_bucket *head;
 	spinlock_t *lock;
-	int twrefcnt = 0;
 
 	WARN_ON(!sk_unhashed(sk));
 
@@ -420,23 +415,22 @@ int __inet_hash_nolisten(struct sock *sk, struct inet_timewait_sock *tw)
 
 	spin_lock(lock);
 	__sk_nulls_add_node_rcu(sk, list);
-	if (tw) {
-		WARN_ON(sk->sk_hash != tw->tw_hash);
-		twrefcnt = inet_twsk_unhash(tw);
+	if (osk) {
+		WARN_ON(sk->sk_hash != osk->sk_hash);
+		sk_nulls_del_node_init_rcu(osk);
 	}
 	spin_unlock(lock);
 	sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);
-	return twrefcnt;
 }
 EXPORT_SYMBOL_GPL(__inet_hash_nolisten);
 
-int __inet_hash(struct sock *sk, struct inet_timewait_sock *tw)
+void __inet_hash(struct sock *sk, struct sock *osk)
 {
 	struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo;
 	struct inet_listen_hashbucket *ilb;
 
 	if (sk->sk_state != TCP_LISTEN)
-		return __inet_hash_nolisten(sk, tw);
+		return __inet_hash_nolisten(sk, osk);
 
 	WARN_ON(!sk_unhashed(sk));
 	ilb = &hashinfo->listening_hash[inet_sk_listen_hashfn(sk)];
@@ -445,7 +439,6 @@ int __inet_hash(struct sock *sk, struct inet_timewait_sock *tw)
 	__sk_nulls_add_node_rcu(sk, &ilb->head);
 	sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);
 	spin_unlock(&ilb->lock);
-	return 0;
 }
 EXPORT_SYMBOL(__inet_hash);
 
@@ -492,7 +485,6 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row,
 	struct inet_bind_bucket *tb;
 	int ret;
 	struct net *net = sock_net(sk);
-	int twrefcnt = 1;
 
 	if (!snum) {
 		int i, remaining, low, high, port;
@@ -560,18 +552,15 @@ ok:
 		inet_bind_hash(sk, tb, port);
 		if (sk_unhashed(sk)) {
 			inet_sk(sk)->inet_sport = htons(port);
-			twrefcnt += __inet_hash_nolisten(sk, tw);
+			__inet_hash_nolisten(sk, (struct sock *)tw);
 		}
 		if (tw)
-			twrefcnt += inet_twsk_bind_unhash(tw, hinfo);
+			inet_twsk_bind_unhash(tw, hinfo);
 		spin_unlock(&head->lock);
 
 		if (tw) {
 			inet_twsk_deschedule(tw);
-			while (twrefcnt) {
-				twrefcnt--;
-				inet_twsk_put(tw);
-			}
+			inet_twsk_put(tw);
 		}
 
 		ret = 0;
diff --git a/net/ipv4/inet_timewait_sock.c b/net/ipv4/inet_timewait_sock.c
index 2ffbd16b79e0..92cd4d50404e 100644
--- a/net/ipv4/inet_timewait_sock.c
+++ b/net/ipv4/inet_timewait_sock.c
@@ -17,28 +17,6 @@
 #include <net/ip.h>
 
 
-/**
- *	inet_twsk_unhash - unhash a timewait socket from established hash
- *	@tw: timewait socket
- *
- *	unhash a timewait socket from established hash, if hashed.
- *	ehash lock must be held by caller.
- *	Returns 1 if caller should call inet_twsk_put() after lock release.
- */
-int inet_twsk_unhash(struct inet_timewait_sock *tw)
-{
-	if (hlist_nulls_unhashed(&tw->tw_node))
-		return 0;
-
-	hlist_nulls_del_rcu(&tw->tw_node);
-	sk_nulls_node_init(&tw->tw_node);
-	/*
-	 * We cannot call inet_twsk_put() ourself under lock,
-	 * caller must call it for us.
-	 */
-	return 1;
-}
-
 /**
  *	inet_twsk_bind_unhash - unhash a timewait socket from bind hash
  *	@tw: timewait socket
@@ -48,35 +26,29 @@ int inet_twsk_unhash(struct inet_timewait_sock *tw)
  *	bind hash lock must be held by caller.
  *	Returns 1 if caller should call inet_twsk_put() after lock release.
  */
-int inet_twsk_bind_unhash(struct inet_timewait_sock *tw,
+void inet_twsk_bind_unhash(struct inet_timewait_sock *tw,
 			  struct inet_hashinfo *hashinfo)
 {
 	struct inet_bind_bucket *tb = tw->tw_tb;
 
 	if (!tb)
-		return 0;
+		return;
 
 	__hlist_del(&tw->tw_bind_node);
 	tw->tw_tb = NULL;
 	inet_bind_bucket_destroy(hashinfo->bind_bucket_cachep, tb);
-	/*
-	 * We cannot call inet_twsk_put() ourself under lock,
-	 * caller must call it for us.
-	 */
-	return 1;
+	__sock_put((struct sock *)tw);
 }
 
 /* Must be called with locally disabled BHs. */
 static void inet_twsk_kill(struct inet_timewait_sock *tw)
 {
 	struct inet_hashinfo *hashinfo = tw->tw_dr->hashinfo;
-	struct inet_bind_hashbucket *bhead;
-	int refcnt;
-	/* Unlink from established hashes. */
 	spinlock_t *lock = inet_ehash_lockp(hashinfo, tw->tw_hash);
+	struct inet_bind_hashbucket *bhead;
 
 	spin_lock(lock);
-	refcnt = inet_twsk_unhash(tw);
+	sk_nulls_del_node_init_rcu((struct sock *)tw);
 	spin_unlock(lock);
 
 	/* Disassociate with bind bucket. */
@@ -84,11 +56,9 @@ static void inet_twsk_kill(struct inet_timewait_sock *tw)
 			hashinfo->bhash_size)];
 
 	spin_lock(&bhead->lock);
-	refcnt += inet_twsk_bind_unhash(tw, hashinfo);
+	inet_twsk_bind_unhash(tw, hashinfo);
 	spin_unlock(&bhead->lock);
 
-	BUG_ON(refcnt >= atomic_read(&tw->tw_refcnt));
-	atomic_sub(refcnt, &tw->tw_refcnt);
 	atomic_dec(&tw->tw_dr->tw_count);
 	inet_twsk_put(tw);
 }
diff --git a/net/ipv6/inet6_hashtables.c b/net/ipv6/inet6_hashtables.c
index b4fd96de97e6..a237398aa2b4 100644
--- a/net/ipv6/inet6_hashtables.c
+++ b/net/ipv6/inet6_hashtables.c
@@ -207,7 +207,6 @@ static int __inet6_check_established(struct inet_timewait_death_row *death_row,
 	struct sock *sk2;
 	const struct hlist_nulls_node *node;
 	struct inet_timewait_sock *tw = NULL;
-	int twrefcnt = 0;
 
 	spin_lock(lock);
 
@@ -234,12 +233,10 @@ static int __inet6_check_established(struct inet_timewait_death_row *death_row,
 	WARN_ON(!sk_unhashed(sk));
 	__sk_nulls_add_node_rcu(sk, &head->chain);
 	if (tw) {
-		twrefcnt = inet_twsk_unhash(tw);
+		sk_nulls_del_node_init_rcu((struct sock *)tw);
 		NET_INC_STATS_BH(net, LINUX_MIB_TIMEWAITRECYCLED);
 	}
 	spin_unlock(lock);
-	if (twrefcnt)
-		inet_twsk_put(tw);
 	sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);
 
 	if (twp) {
@@ -247,7 +244,6 @@ static int __inet6_check_established(struct inet_timewait_death_row *death_row,
 	} else if (tw) {
 		/* Silly. Should hash-dance instead... */
 		inet_twsk_deschedule(tw);
-
 		inet_twsk_put(tw);
 	}
 	return 0;

From dbe7faa4045ea83a37b691b12bb02a8f86c2d2e9 Mon Sep 17 00:00:00 2001
From: Eric Dumazet <edumazet@google.com>
Date: Wed, 8 Jul 2015 14:28:30 -0700
Subject: [PATCH 3/3] inet: inet_twsk_deschedule factorization

inet_twsk_deschedule() calls are followed by inet_twsk_put().

Only particular case is in inet_twsk_purge() but there is no point
to defer the inet_twsk_put() after re-enabling BH.

Lets rename inet_twsk_deschedule() to inet_twsk_deschedule_put()
and move the inet_twsk_put() inside.

Signed-off-by: Eric Dumazet <edumazet@google.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
---
 include/net/inet_timewait_sock.h |  2 +-
 net/ipv4/inet_hashtables.c       |  9 +++------
 net/ipv4/inet_timewait_sock.c    | 13 ++++++++-----
 net/ipv4/tcp_ipv4.c              |  3 +--
 net/ipv4/tcp_minisocks.c         |  6 ++----
 net/ipv6/inet6_hashtables.c      |  3 +--
 net/ipv6/tcp_ipv6.c              |  3 +--
 net/netfilter/xt_TPROXY.c        |  6 ++----
 8 files changed, 19 insertions(+), 26 deletions(-)

diff --git a/include/net/inet_timewait_sock.h b/include/net/inet_timewait_sock.h
index 96f52a4711c8..879d6e5a973b 100644
--- a/include/net/inet_timewait_sock.h
+++ b/include/net/inet_timewait_sock.h
@@ -111,7 +111,7 @@ void __inet_twsk_hashdance(struct inet_timewait_sock *tw, struct sock *sk,
 			   struct inet_hashinfo *hashinfo);
 
 void inet_twsk_schedule(struct inet_timewait_sock *tw, const int timeo);
-void inet_twsk_deschedule(struct inet_timewait_sock *tw);
+void inet_twsk_deschedule_put(struct inet_timewait_sock *tw);
 
 void inet_twsk_purge(struct inet_hashinfo *hashinfo,
 		     struct inet_timewait_death_row *twdr, int family);
diff --git a/net/ipv4/inet_hashtables.c b/net/ipv4/inet_hashtables.c
index e58840330da7..f8b3701a6c3c 100644
--- a/net/ipv4/inet_hashtables.c
+++ b/net/ipv4/inet_hashtables.c
@@ -380,8 +380,7 @@ static int __inet_check_established(struct inet_timewait_death_row *death_row,
 		*twp = tw;
 	} else if (tw) {
 		/* Silly. Should hash-dance instead... */
-		inet_twsk_deschedule(tw);
-		inet_twsk_put(tw);
+		inet_twsk_deschedule_put(tw);
 	}
 	return 0;
 
@@ -558,10 +557,8 @@ ok:
 			inet_twsk_bind_unhash(tw, hinfo);
 		spin_unlock(&head->lock);
 
-		if (tw) {
-			inet_twsk_deschedule(tw);
-			inet_twsk_put(tw);
-		}
+		if (tw)
+			inet_twsk_deschedule_put(tw);
 
 		ret = 0;
 		goto out;
diff --git a/net/ipv4/inet_timewait_sock.c b/net/ipv4/inet_timewait_sock.c
index 92cd4d50404e..ae22cc24fbe8 100644
--- a/net/ipv4/inet_timewait_sock.c
+++ b/net/ipv4/inet_timewait_sock.c
@@ -205,13 +205,17 @@ EXPORT_SYMBOL_GPL(inet_twsk_alloc);
  * tcp_input.c to verify this.
  */
 
-/* This is for handling early-kills of TIME_WAIT sockets. */
-void inet_twsk_deschedule(struct inet_timewait_sock *tw)
+/* This is for handling early-kills of TIME_WAIT sockets.
+ * Warning : consume reference.
+ * Caller should not access tw anymore.
+ */
+void inet_twsk_deschedule_put(struct inet_timewait_sock *tw)
 {
 	if (del_timer_sync(&tw->tw_timer))
 		inet_twsk_kill(tw);
+	inet_twsk_put(tw);
 }
-EXPORT_SYMBOL(inet_twsk_deschedule);
+EXPORT_SYMBOL(inet_twsk_deschedule_put);
 
 void inet_twsk_schedule(struct inet_timewait_sock *tw, const int timeo)
 {
@@ -281,9 +285,8 @@ restart:
 
 			rcu_read_unlock();
 			local_bh_disable();
-			inet_twsk_deschedule(tw);
+			inet_twsk_deschedule_put(tw);
 			local_bh_enable();
-			inet_twsk_put(tw);
 			goto restart_rcu;
 		}
 		/* If the nulls value we got at the end of this lookup is
diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c
index d7d4c2b79cf2..486ba96ae91a 100644
--- a/net/ipv4/tcp_ipv4.c
+++ b/net/ipv4/tcp_ipv4.c
@@ -1683,8 +1683,7 @@ do_time_wait:
 							iph->daddr, th->dest,
 							inet_iif(skb));
 		if (sk2) {
-			inet_twsk_deschedule(inet_twsk(sk));
-			inet_twsk_put(inet_twsk(sk));
+			inet_twsk_deschedule_put(inet_twsk(sk));
 			sk = sk2;
 			goto process;
 		}
diff --git a/net/ipv4/tcp_minisocks.c b/net/ipv4/tcp_minisocks.c
index 4bc00cb79e60..6d8795b066ac 100644
--- a/net/ipv4/tcp_minisocks.c
+++ b/net/ipv4/tcp_minisocks.c
@@ -147,8 +147,7 @@ tcp_timewait_state_process(struct inet_timewait_sock *tw, struct sk_buff *skb,
 		if (!th->fin ||
 		    TCP_SKB_CB(skb)->end_seq != tcptw->tw_rcv_nxt + 1) {
 kill_with_rst:
-			inet_twsk_deschedule(tw);
-			inet_twsk_put(tw);
+			inet_twsk_deschedule_put(tw);
 			return TCP_TW_RST;
 		}
 
@@ -198,8 +197,7 @@ kill_with_rst:
 			 */
 			if (sysctl_tcp_rfc1337 == 0) {
 kill:
-				inet_twsk_deschedule(tw);
-				inet_twsk_put(tw);
+				inet_twsk_deschedule_put(tw);
 				return TCP_TW_SUCCESS;
 			}
 		}
diff --git a/net/ipv6/inet6_hashtables.c b/net/ipv6/inet6_hashtables.c
index a237398aa2b4..6ac8dad0138a 100644
--- a/net/ipv6/inet6_hashtables.c
+++ b/net/ipv6/inet6_hashtables.c
@@ -243,8 +243,7 @@ static int __inet6_check_established(struct inet_timewait_death_row *death_row,
 		*twp = tw;
 	} else if (tw) {
 		/* Silly. Should hash-dance instead... */
-		inet_twsk_deschedule(tw);
-		inet_twsk_put(tw);
+		inet_twsk_deschedule_put(tw);
 	}
 	return 0;
 
diff --git a/net/ipv6/tcp_ipv6.c b/net/ipv6/tcp_ipv6.c
index 6748c4277aff..d540846a1a79 100644
--- a/net/ipv6/tcp_ipv6.c
+++ b/net/ipv6/tcp_ipv6.c
@@ -1481,8 +1481,7 @@ do_time_wait:
 					    ntohs(th->dest), tcp_v6_iif(skb));
 		if (sk2) {
 			struct inet_timewait_sock *tw = inet_twsk(sk);
-			inet_twsk_deschedule(tw);
-			inet_twsk_put(tw);
+			inet_twsk_deschedule_put(tw);
 			sk = sk2;
 			tcp_v6_restore_cb(skb);
 			goto process;
diff --git a/net/netfilter/xt_TPROXY.c b/net/netfilter/xt_TPROXY.c
index cca96cec1b68..d0c96c5ae29a 100644
--- a/net/netfilter/xt_TPROXY.c
+++ b/net/netfilter/xt_TPROXY.c
@@ -272,8 +272,7 @@ tproxy_handle_time_wait4(struct sk_buff *skb, __be32 laddr, __be16 lport,
 					    hp->source, lport ? lport : hp->dest,
 					    skb->dev, NFT_LOOKUP_LISTENER);
 		if (sk2) {
-			inet_twsk_deschedule(inet_twsk(sk));
-			inet_twsk_put(inet_twsk(sk));
+			inet_twsk_deschedule_put(inet_twsk(sk));
 			sk = sk2;
 		}
 	}
@@ -437,8 +436,7 @@ tproxy_handle_time_wait6(struct sk_buff *skb, int tproto, int thoff,
 					    tgi->lport ? tgi->lport : hp->dest,
 					    skb->dev, NFT_LOOKUP_LISTENER);
 		if (sk2) {
-			inet_twsk_deschedule(inet_twsk(sk));
-			inet_twsk_put(inet_twsk(sk));
+			inet_twsk_deschedule_put(inet_twsk(sk));
 			sk = sk2;
 		}
 	}