diff --git a/include/net/sock.h b/include/net/sock.h
index 78c292f15ffc..efb7c0ff30d8 100644
--- a/include/net/sock.h
+++ b/include/net/sock.h
@@ -378,7 +378,7 @@ struct sock {
 #ifdef CONFIG_XFRM
 	struct xfrm_policy __rcu *sk_policy[2];
 #endif
-	struct dst_entry	*sk_rx_dst;
+	struct dst_entry __rcu	*sk_rx_dst;
 	struct dst_entry __rcu	*sk_dst_cache;
 	/* Note: 32bit hole on 64bit arches */
 	atomic_t		sk_wmem_alloc;
diff --git a/net/ipv4/af_inet.c b/net/ipv4/af_inet.c
index 970a498c1166..c897d3fd457e 100644
--- a/net/ipv4/af_inet.c
+++ b/net/ipv4/af_inet.c
@@ -156,7 +156,7 @@ void inet_sock_destruct(struct sock *sk)
 
 	kfree(rcu_dereference_protected(inet->inet_opt, 1));
 	dst_release(rcu_dereference_check(sk->sk_dst_cache, 1));
-	dst_release(sk->sk_rx_dst);
+	dst_release(rcu_dereference_protected(sk->sk_rx_dst, 1));
 	sk_refcnt_debug_dec(sk);
 }
 EXPORT_SYMBOL(inet_sock_destruct);
diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c
index 4252aa0b5143..29b45910fafe 100644
--- a/net/ipv4/tcp.c
+++ b/net/ipv4/tcp.c
@@ -2318,8 +2318,7 @@ int tcp_disconnect(struct sock *sk, int flags)
 	tcp_init_send_head(sk);
 	memset(&tp->rx_opt, 0, sizeof(tp->rx_opt));
 	__sk_dst_reset(sk);
-	dst_release(sk->sk_rx_dst);
-	sk->sk_rx_dst = NULL;
+	dst_release(xchg((__force struct dst_entry **)&sk->sk_rx_dst, NULL));
 	tcp_saved_syn_free(tp);
 	tp->segs_in = 0;
 	tp->segs_out = 0;
diff --git a/net/ipv4/tcp_input.c b/net/ipv4/tcp_input.c
index 2029e7a36cbb..9c7f716aab44 100644
--- a/net/ipv4/tcp_input.c
+++ b/net/ipv4/tcp_input.c
@@ -5481,7 +5481,7 @@ void tcp_rcv_established(struct sock *sk, struct sk_buff *skb,
 {
 	struct tcp_sock *tp = tcp_sk(sk);
 
-	if (unlikely(!sk->sk_rx_dst))
+	if (unlikely(!rcu_access_pointer(sk->sk_rx_dst)))
 		inet_csk(sk)->icsk_af_ops->sk_rx_dst_set(sk, skb);
 	/*
 	 *	Header prediction.
diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c
index 6f895694cca1..1fd1d590fa2b 100644
--- a/net/ipv4/tcp_ipv4.c
+++ b/net/ipv4/tcp_ipv4.c
@@ -1404,15 +1404,18 @@ int tcp_v4_do_rcv(struct sock *sk, struct sk_buff *skb)
 	struct sock *rsk;
 
 	if (sk->sk_state == TCP_ESTABLISHED) { /* Fast path */
-		struct dst_entry *dst = sk->sk_rx_dst;
+		struct dst_entry *dst;
+
+		dst = rcu_dereference_protected(sk->sk_rx_dst,
+						lockdep_sock_is_held(sk));
 
 		sock_rps_save_rxhash(sk, skb);
 		sk_mark_napi_id(sk, skb);
 		if (dst) {
 			if (inet_sk(sk)->rx_dst_ifindex != skb->skb_iif ||
 			    !dst->ops->check(dst, 0)) {
+				RCU_INIT_POINTER(sk->sk_rx_dst, NULL);
 				dst_release(dst);
-				sk->sk_rx_dst = NULL;
 			}
 		}
 		tcp_rcv_established(sk, skb, tcp_hdr(skb), skb->len);
@@ -1489,7 +1492,7 @@ void tcp_v4_early_demux(struct sk_buff *skb)
 		skb->sk = sk;
 		skb->destructor = sock_edemux;
 		if (sk_fullsock(sk)) {
-			struct dst_entry *dst = READ_ONCE(sk->sk_rx_dst);
+			struct dst_entry *dst = rcu_dereference(sk->sk_rx_dst);
 
 			if (dst)
 				dst = dst_check(dst, 0);
@@ -1524,7 +1527,7 @@ bool tcp_prequeue(struct sock *sk, struct sk_buff *skb)
 	 * Instead of doing full sk_rx_dst validity here, let's perform
 	 * an optimistic check.
 	 */
-	if (likely(sk->sk_rx_dst))
+	if (likely(rcu_access_pointer(sk->sk_rx_dst)))
 		skb_dst_drop(skb);
 	else
 		skb_dst_force_safe(skb);
@@ -1818,7 +1821,7 @@ void inet_sk_rx_dst_set(struct sock *sk, const struct sk_buff *skb)
 	struct dst_entry *dst = skb_dst(skb);
 
 	if (dst && dst_hold_safe(dst)) {
-		sk->sk_rx_dst = dst;
+		rcu_assign_pointer(sk->sk_rx_dst, dst);
 		inet_sk(sk)->rx_dst_ifindex = skb->skb_iif;
 	}
 }
diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
index 8770966a564b..888c6eb8f0a0 100644
--- a/net/ipv4/udp.c
+++ b/net/ipv4/udp.c
@@ -1630,7 +1630,7 @@ static void udp_sk_rx_dst_set(struct sock *sk, struct dst_entry *dst)
 	struct dst_entry *old;
 
 	dst_hold(dst);
-	old = xchg(&sk->sk_rx_dst, dst);
+	old = xchg((__force struct dst_entry **)&sk->sk_rx_dst, dst);
 	dst_release(old);
 }
 
@@ -1815,7 +1815,7 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
 		struct dst_entry *dst = skb_dst(skb);
 		int ret;
 
-		if (unlikely(sk->sk_rx_dst != dst))
+		if (unlikely(rcu_dereference(sk->sk_rx_dst) != dst))
 			udp_sk_rx_dst_set(sk, dst);
 
 		ret = udp_unicast_rcv_skb(sk, skb, uh);
@@ -1974,7 +1974,7 @@ void udp_v4_early_demux(struct sk_buff *skb)
 
 	skb->sk = sk;
 	skb->destructor = sock_efree;
-	dst = READ_ONCE(sk->sk_rx_dst);
+	dst = rcu_dereference(sk->sk_rx_dst);
 
 	if (dst)
 		dst = dst_check(dst, 0);
diff --git a/net/ipv6/tcp_ipv6.c b/net/ipv6/tcp_ipv6.c
index 6b8da8d4d1cc..d8b907910cc9 100644
--- a/net/ipv6/tcp_ipv6.c
+++ b/net/ipv6/tcp_ipv6.c
@@ -95,7 +95,7 @@ static void inet6_sk_rx_dst_set(struct sock *sk, const struct sk_buff *skb)
 	if (dst && dst_hold_safe(dst)) {
 		const struct rt6_info *rt = (const struct rt6_info *)dst;
 
-		sk->sk_rx_dst = dst;
+		rcu_assign_pointer(sk->sk_rx_dst, dst);
 		inet_sk(sk)->rx_dst_ifindex = skb->skb_iif;
 		inet6_sk(sk)->rx_dst_cookie = rt6_get_cookie(rt);
 	}
@@ -1285,15 +1285,18 @@ static int tcp_v6_do_rcv(struct sock *sk, struct sk_buff *skb)
 		opt_skb = skb_clone(skb, sk_gfp_mask(sk, GFP_ATOMIC));
 
 	if (sk->sk_state == TCP_ESTABLISHED) { /* Fast path */
-		struct dst_entry *dst = sk->sk_rx_dst;
+		struct dst_entry *dst;
+
+		dst = rcu_dereference_protected(sk->sk_rx_dst,
+						lockdep_sock_is_held(sk));
 
 		sock_rps_save_rxhash(sk, skb);
 		sk_mark_napi_id(sk, skb);
 		if (dst) {
 			if (inet_sk(sk)->rx_dst_ifindex != skb->skb_iif ||
 			    dst->ops->check(dst, np->rx_dst_cookie) == NULL) {
+				RCU_INIT_POINTER(sk->sk_rx_dst, NULL);
 				dst_release(dst);
-				sk->sk_rx_dst = NULL;
 			}
 		}
 
@@ -1621,7 +1624,7 @@ static void tcp_v6_early_demux(struct sk_buff *skb)
 		skb->sk = sk;
 		skb->destructor = sock_edemux;
 		if (sk_fullsock(sk)) {
-			struct dst_entry *dst = READ_ONCE(sk->sk_rx_dst);
+			struct dst_entry *dst = rcu_dereference(sk->sk_rx_dst);
 
 			if (dst)
 				dst = dst_check(dst, inet6_sk(sk)->rx_dst_cookie);
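
Not part of the patch itself: a minimal sketch of the RCU accessor discipline the hunks above apply to sk->sk_rx_dst, written against a hypothetical struct demo_sock whose spinlock stands in for the socket lock (all demo_* names are invented for illustration; the primitives rcu_dereference(), rcu_assign_pointer(), rcu_dereference_protected() and RCU_INIT_POINTER() are the same kernel accessors used in the patch). Lockless readers such as the early-demux paths dereference only inside an RCU read-side section; paths that own the socket publish with rcu_assign_pointer() or clear with RCU_INIT_POINTER() and then drop the reference they hold, as tcp_v4_do_rcv() does.

/*
 * Illustrative sketch only, on a hypothetical demo_sock/demo_dst pair.
 */
#include <linux/rcupdate.h>
#include <linux/spinlock.h>

struct demo_dst;

struct demo_sock {
	spinlock_t		lock;		/* stands in for the socket lock */
	struct demo_dst __rcu	*rx_dst;	/* __rcu: only RCU accessors may touch it */
};

/* Lockless reader (cf. early demux): dereference only inside the read section. */
static bool demo_has_dst(struct demo_sock *dsk)
{
	bool ret;

	rcu_read_lock();
	ret = rcu_dereference(dsk->rx_dst) != NULL;	/* pairs with rcu_assign_pointer() */
	rcu_read_unlock();
	return ret;
}

/* Writer serialized by the lock: publish a new pointer to lockless readers. */
static void demo_publish(struct demo_sock *dsk, struct demo_dst *dst)
{
	spin_lock(&dsk->lock);
	rcu_assign_pointer(dsk->rx_dst, dst);	/* ordered store */
	spin_unlock(&dsk->lock);
}

/* Writer serialized by the lock: clear the pointer and hand back the old value. */
static struct demo_dst *demo_clear(struct demo_sock *dsk)
{
	struct demo_dst *old;

	spin_lock(&dsk->lock);
	old = rcu_dereference_protected(dsk->rx_dst,
					lockdep_is_held(&dsk->lock));
	RCU_INIT_POINTER(dsk->rx_dst, NULL);	/* storing NULL needs no ordering */
	spin_unlock(&dsk->lock);
	return old;	/* caller drops the reference it owns, cf. dst_release() */
}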