diff --git a/net/ipv4/af_inet.c b/net/ipv4/af_inet.c index e05cdc608850..d7ebee3c048d 100644 --- a/net/ipv4/af_inet.c +++ b/net/ipv4/af_inet.c @@ -876,7 +876,7 @@ int inet_shutdown(struct socket *sock, int how) EPOLLHUP, even on eg. unconnected UDP sockets -- RR */ /* fall through */ default: - sk->sk_shutdown |= how; + WRITE_ONCE(sk->sk_shutdown, sk->sk_shutdown | how); if (sk->sk_prot->shutdown) sk->sk_prot->shutdown(sk, how); break; diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c index e45c09977c60..8d7933989de0 100644 --- a/net/ipv4/tcp.c +++ b/net/ipv4/tcp.c @@ -505,6 +505,7 @@ __poll_t tcp_poll(struct file *file, struct socket *sock, poll_table *wait) __poll_t mask; struct sock *sk = sock->sk; const struct tcp_sock *tp = tcp_sk(sk); + u8 shutdown; int state; sock_poll_wait(file, sock, wait); @@ -547,9 +548,10 @@ __poll_t tcp_poll(struct file *file, struct socket *sock, poll_table *wait) * NOTE. Check for TCP_CLOSE is added. The goal is to prevent * blocking on fresh not-connected or disconnected socket. --ANK */ - if (sk->sk_shutdown == SHUTDOWN_MASK || state == TCP_CLOSE) + shutdown = READ_ONCE(sk->sk_shutdown); + if (shutdown == SHUTDOWN_MASK || state == TCP_CLOSE) mask |= EPOLLHUP; - if (sk->sk_shutdown & RCV_SHUTDOWN) + if (shutdown & RCV_SHUTDOWN) mask |= EPOLLIN | EPOLLRDNORM | EPOLLRDHUP; /* Connected or passive Fast Open socket? */ @@ -565,7 +567,7 @@ __poll_t tcp_poll(struct file *file, struct socket *sock, poll_table *wait) if (tcp_stream_is_readable(tp, target, sk)) mask |= EPOLLIN | EPOLLRDNORM; - if (!(sk->sk_shutdown & SEND_SHUTDOWN)) { + if (!(shutdown & SEND_SHUTDOWN)) { if (__sk_stream_is_writeable(sk, 1)) { mask |= EPOLLOUT | EPOLLWRNORM; } else { /* send SIGIO later */ @@ -2357,7 +2359,7 @@ void __tcp_close(struct sock *sk, long timeout) int data_was_unread = 0; int state; - sk->sk_shutdown = SHUTDOWN_MASK; + WRITE_ONCE(sk->sk_shutdown, SHUTDOWN_MASK); if (sk->sk_state == TCP_LISTEN) { tcp_set_state(sk, TCP_CLOSE); @@ -2629,7 +2631,7 @@ int tcp_disconnect(struct sock *sk, int flags) if (!(sk->sk_userlocks & SOCK_BINDADDR_LOCK)) inet_reset_saddr(sk); - sk->sk_shutdown = 0; + WRITE_ONCE(sk->sk_shutdown, 0); sock_reset_flag(sk, SOCK_DONE); tp->srtt_us = 0; tp->mdev_us = jiffies_to_usecs(TCP_TIMEOUT_INIT); @@ -3905,7 +3907,7 @@ void tcp_done(struct sock *sk) if (req) reqsk_fastopen_remove(sk, req, false); - sk->sk_shutdown = SHUTDOWN_MASK; + WRITE_ONCE(sk->sk_shutdown, SHUTDOWN_MASK); if (!sock_flag(sk, SOCK_DEAD)) sk->sk_state_change(sk); diff --git a/net/ipv4/tcp_input.c b/net/ipv4/tcp_input.c index 982fe464156a..61243531a7f4 100644 --- a/net/ipv4/tcp_input.c +++ b/net/ipv4/tcp_input.c @@ -4216,7 +4216,7 @@ void tcp_fin(struct sock *sk) inet_csk_schedule_ack(sk); - sk->sk_shutdown |= RCV_SHUTDOWN; + WRITE_ONCE(sk->sk_shutdown, sk->sk_shutdown | RCV_SHUTDOWN); sock_set_flag(sk, SOCK_DONE); switch (sk->sk_state) { @@ -6354,7 +6354,7 @@ int tcp_rcv_state_process(struct sock *sk, struct sk_buff *skb) break; tcp_set_state(sk, TCP_FIN_WAIT2); - sk->sk_shutdown |= SEND_SHUTDOWN; + WRITE_ONCE(sk->sk_shutdown, sk->sk_shutdown | SEND_SHUTDOWN); sk_dst_confirm(sk);