diff --git a/include/net/inet_common.h b/include/net/inet_common.h index b86b8e21de7f..f50a644d87a9 100644 --- a/include/net/inet_common.h +++ b/include/net/inet_common.h @@ -40,8 +40,10 @@ int inet_recvmsg(struct socket *sock, struct msghdr *msg, size_t size, int flags); int inet_shutdown(struct socket *sock, int how); int inet_listen(struct socket *sock, int backlog); +int __inet_listen_sk(struct sock *sk, int backlog); void inet_sock_destruct(struct sock *sk); int inet_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len); +int inet_bind_sk(struct sock *sk, struct sockaddr *uaddr, int addr_len); /* Don't allocate port at this moment, defer to connect. */ #define BIND_FORCE_ADDRESS_NO_PORT (1 << 0) /* Grab and release socket lock. */ diff --git a/include/net/ipv6.h b/include/net/ipv6.h index 2acc4c808d45..22643ffc2df8 100644 --- a/include/net/ipv6.h +++ b/include/net/ipv6.h @@ -1216,6 +1216,7 @@ void inet6_cleanup_sock(struct sock *sk); void inet6_sock_destruct(struct sock *sk); int inet6_release(struct socket *sock); int inet6_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len); +int inet6_bind_sk(struct sock *sk, struct sockaddr *uaddr, int addr_len); int inet6_getname(struct socket *sock, struct sockaddr *uaddr, int peer); int inet6_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg); diff --git a/net/ipv4/af_inet.c b/net/ipv4/af_inet.c index 9b2ca2fcc5a1..c59da65f19d2 100644 --- a/net/ipv4/af_inet.c +++ b/net/ipv4/af_inet.c @@ -187,24 +187,13 @@ static int inet_autobind(struct sock *sk) return 0; } -/* - * Move a socket into listening state. - */ -int inet_listen(struct socket *sock, int backlog) +int __inet_listen_sk(struct sock *sk, int backlog) { - struct sock *sk = sock->sk; - unsigned char old_state; + unsigned char old_state = sk->sk_state; int err, tcp_fastopen; - lock_sock(sk); - - err = -EINVAL; - if (sock->state != SS_UNCONNECTED || sock->type != SOCK_STREAM) - goto out; - - old_state = sk->sk_state; if (!((1 << old_state) & (TCPF_CLOSE | TCPF_LISTEN))) - goto out; + return -EINVAL; WRITE_ONCE(sk->sk_max_ack_backlog, backlog); /* Really, if the socket is already in listen state @@ -227,10 +216,27 @@ int inet_listen(struct socket *sock, int backlog) err = inet_csk_listen_start(sk); if (err) - goto out; + return err; + tcp_call_bpf(sk, BPF_SOCK_OPS_TCP_LISTEN_CB, 0, NULL); } - err = 0; + return 0; +} + +/* + * Move a socket into listening state. + */ +int inet_listen(struct socket *sock, int backlog) +{ + struct sock *sk = sock->sk; + int err = -EINVAL; + + lock_sock(sk); + + if (sock->state != SS_UNCONNECTED || sock->type != SOCK_STREAM) + goto out; + + err = __inet_listen_sk(sk, backlog); out: release_sock(sk); @@ -431,9 +437,8 @@ int inet_release(struct socket *sock) } EXPORT_SYMBOL(inet_release); -int inet_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len) +int inet_bind_sk(struct sock *sk, struct sockaddr *uaddr, int addr_len) { - struct sock *sk = sock->sk; u32 flags = BIND_WITH_LOCK; int err; @@ -454,6 +459,11 @@ int inet_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len) return __inet_bind(sk, uaddr, addr_len, flags); } + +int inet_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len) +{ + return inet_bind_sk(sock->sk, uaddr, addr_len); +} EXPORT_SYMBOL(inet_bind); int __inet_bind(struct sock *sk, struct sockaddr *uaddr, int addr_len, diff --git a/net/ipv6/af_inet6.c b/net/ipv6/af_inet6.c index 9f9c4b838664..3ec0359d5c1f 100644 --- a/net/ipv6/af_inet6.c +++ b/net/ipv6/af_inet6.c @@ -435,10 +435,8 @@ out_unlock: goto out; } -/* bind for INET6 API */ -int inet6_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len) +int inet6_bind_sk(struct sock *sk, struct sockaddr *uaddr, int addr_len) { - struct sock *sk = sock->sk; u32 flags = BIND_WITH_LOCK; const struct proto *prot; int err = 0; @@ -462,6 +460,12 @@ int inet6_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len) return __inet6_bind(sk, uaddr, addr_len, flags); } + +/* bind for INET6 API */ +int inet6_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len) +{ + return inet6_bind_sk(sock->sk, uaddr, addr_len); +} EXPORT_SYMBOL(inet6_bind); int inet6_release(struct socket *sock) diff --git a/net/mptcp/pm_netlink.c b/net/mptcp/pm_netlink.c index 5692daf57a4d..c75d9d88a053 100644 --- a/net/mptcp/pm_netlink.c +++ b/net/mptcp/pm_netlink.c @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -1005,8 +1006,7 @@ static int mptcp_pm_nl_create_listen_socket(struct sock *sk, bool is_ipv6 = sk->sk_family == AF_INET6; int addrlen = sizeof(struct sockaddr_in); struct sockaddr_storage addr; - struct socket *ssock; - struct sock *newsk; + struct sock *newsk, *ssk; int backlog = 1024; int err; @@ -1032,28 +1032,32 @@ static int mptcp_pm_nl_create_listen_socket(struct sock *sk, &mptcp_keys[is_ipv6]); lock_sock(newsk); - ssock = __mptcp_nmpc_socket(mptcp_sk(newsk)); + ssk = __mptcp_nmpc_sk(mptcp_sk(newsk)); release_sock(newsk); - if (IS_ERR(ssock)) - return PTR_ERR(ssock); + if (IS_ERR(ssk)) + return PTR_ERR(ssk); mptcp_info2sockaddr(&entry->addr, &addr, entry->addr.family); #if IS_ENABLED(CONFIG_MPTCP_IPV6) if (entry->addr.family == AF_INET6) addrlen = sizeof(struct sockaddr_in6); #endif - err = kernel_bind(ssock, (struct sockaddr *)&addr, addrlen); + if (ssk->sk_family == AF_INET) + err = inet_bind_sk(ssk, (struct sockaddr *)&addr, addrlen); +#if IS_ENABLED(CONFIG_MPTCP_IPV6) + else if (ssk->sk_family == AF_INET6) + err = inet6_bind_sk(ssk, (struct sockaddr *)&addr, addrlen); +#endif if (err) return err; inet_sk_state_store(newsk, TCP_LISTEN); - err = kernel_listen(ssock, backlog); - if (err) - return err; - - mptcp_event_pm_listener(ssock->sk, MPTCP_EVENT_LISTENER_CREATED); - - return 0; + lock_sock(ssk); + err = __inet_listen_sk(ssk, backlog); + if (!err) + mptcp_event_pm_listener(ssk, MPTCP_EVENT_LISTENER_CREATED); + release_sock(ssk); + return err; } int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct mptcp_addr_info *skc) diff --git a/net/mptcp/protocol.c b/net/mptcp/protocol.c index 48e649fe2360..6ea0a1da8068 100644 --- a/net/mptcp/protocol.c +++ b/net/mptcp/protocol.c @@ -92,7 +92,6 @@ static int __mptcp_socket_create(struct mptcp_sock *msk) msk->scaling_ratio = tcp_sk(ssock->sk)->scaling_ratio; WRITE_ONCE(msk->first, ssock->sk); - WRITE_ONCE(msk->subflow, ssock); subflow = mptcp_subflow_ctx(ssock->sk); list_add(&subflow->node, &msk->conn_list); sock_hold(ssock->sk); @@ -102,6 +101,7 @@ static int __mptcp_socket_create(struct mptcp_sock *msk) /* This is the first subflow, always with id 0 */ subflow->local_id_valid = 1; mptcp_sock_graft(msk->first, sk->sk_socket); + iput(SOCK_INODE(ssock)); return 0; } @@ -109,7 +109,7 @@ static int __mptcp_socket_create(struct mptcp_sock *msk) /* If the MPC handshake is not started, returns the first subflow, * eventually allocating it. */ -struct socket *__mptcp_nmpc_socket(struct mptcp_sock *msk) +struct sock *__mptcp_nmpc_sk(struct mptcp_sock *msk) { struct sock *sk = (struct sock *)msk; int ret; @@ -117,10 +117,7 @@ struct socket *__mptcp_nmpc_socket(struct mptcp_sock *msk) if (!((1 << sk->sk_state) & (TCPF_CLOSE | TCPF_LISTEN))) return ERR_PTR(-EINVAL); - if (!msk->subflow) { - if (msk->first) - return ERR_PTR(-EINVAL); - + if (!msk->first) { ret = __mptcp_socket_create(msk); if (ret) return ERR_PTR(ret); @@ -128,7 +125,7 @@ struct socket *__mptcp_nmpc_socket(struct mptcp_sock *msk) mptcp_sockopt_sync(msk, msk->first); } - return msk->subflow; + return msk->first; } static void mptcp_drop(struct sock *sk, struct sk_buff *skb) @@ -1643,7 +1640,6 @@ static int mptcp_sendmsg_fastopen(struct sock *sk, struct msghdr *msg, { unsigned int saved_flags = msg->msg_flags; struct mptcp_sock *msk = mptcp_sk(sk); - struct socket *ssock; struct sock *ssk; int ret; @@ -1654,9 +1650,9 @@ static int mptcp_sendmsg_fastopen(struct sock *sk, struct msghdr *msg, * fastopen attempt, no need to check for additional subflow status. */ if (msg->msg_flags & MSG_FASTOPEN) { - ssock = __mptcp_nmpc_socket(msk); - if (IS_ERR(ssock)) - return PTR_ERR(ssock); + ssk = __mptcp_nmpc_sk(msk); + if (IS_ERR(ssk)) + return PTR_ERR(ssk); } if (!msk->first) return -EINVAL; @@ -2242,14 +2238,6 @@ static struct sock *mptcp_subflow_get_retrans(struct mptcp_sock *msk) return min_stale_count > 1 ? backup : NULL; } -static void mptcp_dispose_initial_subflow(struct mptcp_sock *msk) -{ - if (msk->subflow) { - iput(SOCK_INODE(msk->subflow)); - WRITE_ONCE(msk->subflow, NULL); - } -} - bool __mptcp_retransmit_pending_data(struct sock *sk) { struct mptcp_data_frag *cur, *rtx_head; @@ -2328,7 +2316,7 @@ static void __mptcp_close_ssk(struct sock *sk, struct sock *ssk, goto out_release; } - dispose_it = !msk->subflow || ssk != msk->subflow->sk; + dispose_it = msk->free_first || ssk != msk->first; if (dispose_it) list_del(&subflow->node); @@ -2349,7 +2337,6 @@ static void __mptcp_close_ssk(struct sock *sk, struct sock *ssk, * disconnect should never fail */ WARN_ON_ONCE(tcp_disconnect(ssk, 0)); - msk->subflow->state = SS_UNCONNECTED; mptcp_subflow_ctx_reset(subflow); release_sock(ssk); @@ -2662,7 +2649,7 @@ unlock: sock_put(sk); } -static int __mptcp_init_sock(struct sock *sk) +static void __mptcp_init_sock(struct sock *sk) { struct mptcp_sock *msk = mptcp_sk(sk); @@ -2689,8 +2676,6 @@ static int __mptcp_init_sock(struct sock *sk) /* re-use the csk retrans timer for MPTCP-level retrans */ timer_setup(&msk->sk.icsk_retransmit_timer, mptcp_retransmit_timer, 0); timer_setup(&sk->sk_timer, mptcp_timeout_timer, 0); - - return 0; } static void mptcp_ca_reset(struct sock *sk) @@ -2708,11 +2693,8 @@ static void mptcp_ca_reset(struct sock *sk) static int mptcp_init_sock(struct sock *sk) { struct net *net = sock_net(sk); - int ret; - ret = __mptcp_init_sock(sk); - if (ret) - return ret; + __mptcp_init_sock(sk); if (!mptcp_is_enabled(net)) return -ENOPROTOOPT; @@ -3110,7 +3092,6 @@ struct sock *mptcp_sk_clone_init(const struct sock *sk, msk = mptcp_sk(nsk); msk->local_key = subflow_req->local_key; msk->token = subflow_req->token; - WRITE_ONCE(msk->subflow, NULL); msk->in_accept_queue = 1; WRITE_ONCE(msk->fully_established, false); if (mp_opt->suboptions & OPTION_MPTCP_CSUMREQD) @@ -3174,25 +3155,17 @@ void mptcp_rcv_space_init(struct mptcp_sock *msk, const struct sock *ssk) WRITE_ONCE(msk->wnd_end, msk->snd_nxt + tcp_sk(ssk)->snd_wnd); } -static struct sock *mptcp_accept(struct sock *sk, int flags, int *err, +static struct sock *mptcp_accept(struct sock *ssk, int flags, int *err, bool kern) { - struct mptcp_sock *msk = mptcp_sk(sk); - struct socket *listener; struct sock *newsk; - listener = READ_ONCE(msk->subflow); - if (WARN_ON_ONCE(!listener)) { - *err = -EINVAL; - return NULL; - } - - pr_debug("msk=%p, listener=%p", msk, mptcp_subflow_ctx(listener->sk)); - newsk = inet_csk_accept(listener->sk, flags, err, kern); + pr_debug("ssk=%p, listener=%p", ssk, mptcp_subflow_ctx(ssk)); + newsk = inet_csk_accept(ssk, flags, err, kern); if (!newsk) return NULL; - pr_debug("msk=%p, subflow is mptcp=%d", msk, sk_is_mptcp(newsk)); + pr_debug("newsk=%p, subflow is mptcp=%d", newsk, sk_is_mptcp(newsk)); if (sk_is_mptcp(newsk)) { struct mptcp_subflow_context *subflow; struct sock *new_mptcp_sock; @@ -3209,9 +3182,9 @@ static struct sock *mptcp_accept(struct sock *sk, int flags, int *err, } newsk = new_mptcp_sock; - MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_MPCAPABLEPASSIVEACK); + MPTCP_INC_STATS(sock_net(ssk), MPTCP_MIB_MPCAPABLEPASSIVEACK); } else { - MPTCP_INC_STATS(sock_net(sk), + MPTCP_INC_STATS(sock_net(ssk), MPTCP_MIB_MPCAPABLEPASSIVEFALLBACK); } @@ -3252,10 +3225,8 @@ static void mptcp_destroy(struct sock *sk) { struct mptcp_sock *msk = mptcp_sk(sk); - /* clears msk->subflow, allowing the following to close - * even the initial subflow - */ - mptcp_dispose_initial_subflow(msk); + /* allow the following to close even the initial subflow */ + msk->free_first = 1; mptcp_destroy_common(msk, 0); sk_sockets_allocated_dec(sk); } @@ -3405,14 +3376,12 @@ static void mptcp_unhash(struct sock *sk) static int mptcp_get_port(struct sock *sk, unsigned short snum) { struct mptcp_sock *msk = mptcp_sk(sk); - struct socket *ssock; - ssock = msk->subflow; - pr_debug("msk=%p, subflow=%p", msk, ssock); - if (WARN_ON_ONCE(!ssock)) + pr_debug("msk=%p, ssk=%p", msk, msk->first); + if (WARN_ON_ONCE(!msk->first)) return -EINVAL; - return inet_csk_get_port(ssock->sk, snum); + return inet_csk_get_port(msk->first, snum); } void mptcp_finish_connect(struct sock *ssk) @@ -3587,25 +3556,24 @@ static int mptcp_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len) { struct mptcp_subflow_context *subflow; struct mptcp_sock *msk = mptcp_sk(sk); - struct socket *ssock; int err = -EINVAL; + struct sock *ssk; - ssock = __mptcp_nmpc_socket(msk); - if (IS_ERR(ssock)) - return PTR_ERR(ssock); + ssk = __mptcp_nmpc_sk(msk); + if (IS_ERR(ssk)) + return PTR_ERR(ssk); - mptcp_token_destroy(msk); inet_sk_state_store(sk, TCP_SYN_SENT); - subflow = mptcp_subflow_ctx(ssock->sk); + subflow = mptcp_subflow_ctx(ssk); #ifdef CONFIG_TCP_MD5SIG /* no MPTCP if MD5SIG is enabled on this socket or we may run out of * TCP option space. */ - if (rcu_access_pointer(tcp_sk(ssock->sk)->md5sig_info)) + if (rcu_access_pointer(tcp_sk(ssk)->md5sig_info)) mptcp_subflow_early_fallback(msk, subflow); #endif - if (subflow->request_mptcp && mptcp_token_new_connect(ssock->sk)) { - MPTCP_INC_STATS(sock_net(ssock->sk), MPTCP_MIB_TOKENFALLBACKINIT); + if (subflow->request_mptcp && mptcp_token_new_connect(ssk)) { + MPTCP_INC_STATS(sock_net(ssk), MPTCP_MIB_TOKENFALLBACKINIT); mptcp_subflow_early_fallback(msk, subflow); } if (likely(!__mptcp_check_fallback(msk))) @@ -3614,25 +3582,42 @@ static int mptcp_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len) /* if reaching here via the fastopen/sendmsg path, the caller already * acquired the subflow socket lock, too. */ - if (msk->fastopening) - err = __inet_stream_connect(ssock, uaddr, addr_len, O_NONBLOCK, 1); - else - err = inet_stream_connect(ssock, uaddr, addr_len, O_NONBLOCK); - inet_sk(sk)->defer_connect = inet_sk(ssock->sk)->defer_connect; + if (!msk->fastopening) + lock_sock(ssk); + + /* the following mirrors closely a very small chunk of code from + * __inet_stream_connect() + */ + if (ssk->sk_state != TCP_CLOSE) + goto out; + + if (BPF_CGROUP_PRE_CONNECT_ENABLED(ssk)) { + err = ssk->sk_prot->pre_connect(ssk, uaddr, addr_len); + if (err) + goto out; + } + + err = ssk->sk_prot->connect(ssk, uaddr, addr_len); + if (err < 0) + goto out; + + inet_sk(sk)->defer_connect = inet_sk(ssk)->defer_connect; + +out: + if (!msk->fastopening) + release_sock(ssk); /* on successful connect, the msk state will be moved to established by * subflow_finish_connect() */ - if (unlikely(err && err != -EINPROGRESS)) { - inet_sk_state_store(sk, inet_sk_state_load(ssock->sk)); + if (unlikely(err)) { + /* avoid leaving a dangling token in an unconnected socket */ + mptcp_token_destroy(msk); + inet_sk_state_store(sk, TCP_CLOSE); return err; } - mptcp_copy_inaddrs(sk, ssock->sk); - - /* silence EINPROGRESS and let the caller inet_stream_connect - * handle the connection in progress - */ + mptcp_copy_inaddrs(sk, ssk); return 0; } @@ -3673,22 +3658,27 @@ static struct proto mptcp_prot = { static int mptcp_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len) { struct mptcp_sock *msk = mptcp_sk(sock->sk); - struct socket *ssock; - int err; + struct sock *ssk, *sk = sock->sk; + int err = -EINVAL; - lock_sock(sock->sk); - ssock = __mptcp_nmpc_socket(msk); - if (IS_ERR(ssock)) { - err = PTR_ERR(ssock); + lock_sock(sk); + ssk = __mptcp_nmpc_sk(msk); + if (IS_ERR(ssk)) { + err = PTR_ERR(ssk); goto unlock; } - err = READ_ONCE(ssock->ops)->bind(ssock, uaddr, addr_len); + if (sk->sk_family == AF_INET) + err = inet_bind_sk(ssk, uaddr, addr_len); +#if IS_ENABLED(CONFIG_MPTCP_IPV6) + else if (sk->sk_family == AF_INET6) + err = inet6_bind_sk(ssk, uaddr, addr_len); +#endif if (!err) - mptcp_copy_inaddrs(sock->sk, ssock->sk); + mptcp_copy_inaddrs(sk, ssk); unlock: - release_sock(sock->sk); + release_sock(sk); return err; } @@ -3696,7 +3686,7 @@ static int mptcp_listen(struct socket *sock, int backlog) { struct mptcp_sock *msk = mptcp_sk(sock->sk); struct sock *sk = sock->sk; - struct socket *ssock; + struct sock *ssk; int err; pr_debug("msk=%p", msk); @@ -3707,22 +3697,24 @@ static int mptcp_listen(struct socket *sock, int backlog) if (sock->state != SS_UNCONNECTED || sock->type != SOCK_STREAM) goto unlock; - ssock = __mptcp_nmpc_socket(msk); - if (IS_ERR(ssock)) { - err = PTR_ERR(ssock); + ssk = __mptcp_nmpc_sk(msk); + if (IS_ERR(ssk)) { + err = PTR_ERR(ssk); goto unlock; } - mptcp_token_destroy(msk); inet_sk_state_store(sk, TCP_LISTEN); sock_set_flag(sk, SOCK_RCU_FREE); - err = READ_ONCE(ssock->ops)->listen(ssock, backlog); - inet_sk_state_store(sk, inet_sk_state_load(ssock->sk)); + lock_sock(ssk); + err = __inet_listen_sk(ssk, backlog); + release_sock(ssk); + inet_sk_state_store(sk, inet_sk_state_load(ssk)); + if (!err) { sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1); - mptcp_copy_inaddrs(sk, ssock->sk); - mptcp_event_pm_listener(ssock->sk, MPTCP_EVENT_LISTENER_CREATED); + mptcp_copy_inaddrs(sk, ssk); + mptcp_event_pm_listener(ssk, MPTCP_EVENT_LISTENER_CREATED); } unlock: @@ -3734,8 +3726,7 @@ static int mptcp_stream_accept(struct socket *sock, struct socket *newsock, int flags, bool kern) { struct mptcp_sock *msk = mptcp_sk(sock->sk); - struct socket *ssock; - struct sock *newsk; + struct sock *ssk, *newsk; int err; pr_debug("msk=%p", msk); @@ -3743,11 +3734,11 @@ static int mptcp_stream_accept(struct socket *sock, struct socket *newsock, /* Buggy applications can call accept on socket states other then LISTEN * but no need to allocate the first subflow just to error out. */ - ssock = READ_ONCE(msk->subflow); - if (!ssock) + ssk = READ_ONCE(msk->first); + if (!ssk) return -EINVAL; - newsk = mptcp_accept(sock->sk, flags, &err, kern); + newsk = mptcp_accept(ssk, flags, &err, kern); if (!newsk) return err; @@ -3774,11 +3765,10 @@ static int mptcp_stream_accept(struct socket *sock, struct socket *newsock, /* Do late cleanup for the first subflow as necessary. Also * deal with bad peers not doing a complete shutdown. */ - if (msk->first && - unlikely(inet_sk_state_load(msk->first) == TCP_CLOSE)) { + if (unlikely(inet_sk_state_load(msk->first) == TCP_CLOSE)) { __mptcp_close_ssk(newsk, msk->first, mptcp_subflow_ctx(msk->first), 0); - if (unlikely(list_empty(&msk->conn_list))) + if (unlikely(list_is_singular(&msk->conn_list))) inet_sk_state_store(newsk, TCP_CLOSE); } } @@ -3817,12 +3807,12 @@ static __poll_t mptcp_poll(struct file *file, struct socket *sock, state = inet_sk_state_load(sk); pr_debug("msk=%p state=%d flags=%lx", msk, state, msk->flags); if (state == TCP_LISTEN) { - struct socket *ssock = READ_ONCE(msk->subflow); + struct sock *ssk = READ_ONCE(msk->first); - if (WARN_ON_ONCE(!ssock || !ssock->sk)) + if (WARN_ON_ONCE(!ssk)) return 0; - return inet_csk_listen_poll(ssock->sk); + return inet_csk_listen_poll(ssk); } shutdown = READ_ONCE(sk->sk_shutdown); diff --git a/net/mptcp/protocol.h b/net/mptcp/protocol.h index 79fc5cdb67bc..38c7ea013361 100644 --- a/net/mptcp/protocol.h +++ b/net/mptcp/protocol.h @@ -299,7 +299,8 @@ struct mptcp_sock { cork:1, nodelay:1, fastopening:1, - in_accept_queue:1; + in_accept_queue:1, + free_first:1; struct work_struct work; struct sk_buff *ooo_last_skb; struct rb_root out_of_order_queue; @@ -308,12 +309,10 @@ struct mptcp_sock { struct list_head rtx_queue; struct mptcp_data_frag *first_pending; struct list_head join_list; - struct socket *subflow; /* outgoing connect/listener/!mp_capable - * The mptcp ops can safely dereference, using suitable - * ONCE annotation, the subflow outside the socket - * lock as such sock is freed after close(). - */ - struct sock *first; + struct sock *first; /* The mptcp ops can safely dereference, using suitable + * ONCE annotation, the subflow outside the socket + * lock as such sock is freed after close(). + */ struct mptcp_pm_data pm; struct { u32 space; /* bytes copied in last measurement window */ @@ -640,7 +639,7 @@ void __mptcp_subflow_send_ack(struct sock *ssk); void mptcp_subflow_reset(struct sock *ssk); void mptcp_subflow_queue_clean(struct sock *sk, struct sock *ssk); void mptcp_sock_graft(struct sock *sk, struct socket *parent); -struct socket *__mptcp_nmpc_socket(struct mptcp_sock *msk); +struct sock *__mptcp_nmpc_sk(struct mptcp_sock *msk); bool __mptcp_close(struct sock *sk, long timeout); void mptcp_cancel_work(struct sock *sk); void __mptcp_unaccepted_force_close(struct sock *sk); diff --git a/net/mptcp/sockopt.c b/net/mptcp/sockopt.c index a3f1fe810cc9..21bc46acbe38 100644 --- a/net/mptcp/sockopt.c +++ b/net/mptcp/sockopt.c @@ -292,7 +292,7 @@ static int mptcp_setsockopt_sol_socket(struct mptcp_sock *msk, int optname, sockptr_t optval, unsigned int optlen) { struct sock *sk = (struct sock *)msk; - struct socket *ssock; + struct sock *ssk; int ret; switch (optname) { @@ -301,22 +301,22 @@ static int mptcp_setsockopt_sol_socket(struct mptcp_sock *msk, int optname, case SO_BINDTODEVICE: case SO_BINDTOIFINDEX: lock_sock(sk); - ssock = __mptcp_nmpc_socket(msk); - if (IS_ERR(ssock)) { + ssk = __mptcp_nmpc_sk(msk); + if (IS_ERR(ssk)) { release_sock(sk); - return PTR_ERR(ssock); + return PTR_ERR(ssk); } - ret = sock_setsockopt(ssock, SOL_SOCKET, optname, optval, optlen); + ret = sk_setsockopt(ssk, SOL_SOCKET, optname, optval, optlen); if (ret == 0) { if (optname == SO_REUSEPORT) - sk->sk_reuseport = ssock->sk->sk_reuseport; + sk->sk_reuseport = ssk->sk_reuseport; else if (optname == SO_REUSEADDR) - sk->sk_reuse = ssock->sk->sk_reuse; + sk->sk_reuse = ssk->sk_reuse; else if (optname == SO_BINDTODEVICE) - sk->sk_bound_dev_if = ssock->sk->sk_bound_dev_if; + sk->sk_bound_dev_if = ssk->sk_bound_dev_if; else if (optname == SO_BINDTOIFINDEX) - sk->sk_bound_dev_if = ssock->sk->sk_bound_dev_if; + sk->sk_bound_dev_if = ssk->sk_bound_dev_if; } release_sock(sk); return ret; @@ -390,20 +390,20 @@ static int mptcp_setsockopt_v6(struct mptcp_sock *msk, int optname, { struct sock *sk = (struct sock *)msk; int ret = -EOPNOTSUPP; - struct socket *ssock; + struct sock *ssk; switch (optname) { case IPV6_V6ONLY: case IPV6_TRANSPARENT: case IPV6_FREEBIND: lock_sock(sk); - ssock = __mptcp_nmpc_socket(msk); - if (IS_ERR(ssock)) { + ssk = __mptcp_nmpc_sk(msk); + if (IS_ERR(ssk)) { release_sock(sk); - return PTR_ERR(ssock); + return PTR_ERR(ssk); } - ret = tcp_setsockopt(ssock->sk, SOL_IPV6, optname, optval, optlen); + ret = tcp_setsockopt(ssk, SOL_IPV6, optname, optval, optlen); if (ret != 0) { release_sock(sk); return ret; @@ -413,13 +413,13 @@ static int mptcp_setsockopt_v6(struct mptcp_sock *msk, int optname, switch (optname) { case IPV6_V6ONLY: - sk->sk_ipv6only = ssock->sk->sk_ipv6only; + sk->sk_ipv6only = ssk->sk_ipv6only; break; case IPV6_TRANSPARENT: - inet_sk(sk)->transparent = inet_sk(ssock->sk)->transparent; + inet_sk(sk)->transparent = inet_sk(ssk)->transparent; break; case IPV6_FREEBIND: - inet_sk(sk)->freebind = inet_sk(ssock->sk)->freebind; + inet_sk(sk)->freebind = inet_sk(ssk)->freebind; break; } @@ -685,7 +685,7 @@ static int mptcp_setsockopt_sol_ip_set_transparent(struct mptcp_sock *msk, int o { struct sock *sk = (struct sock *)msk; struct inet_sock *issk; - struct socket *ssock; + struct sock *ssk; int err; err = ip_setsockopt(sk, SOL_IP, optname, optval, optlen); @@ -694,13 +694,13 @@ static int mptcp_setsockopt_sol_ip_set_transparent(struct mptcp_sock *msk, int o lock_sock(sk); - ssock = __mptcp_nmpc_socket(msk); - if (IS_ERR(ssock)) { + ssk = __mptcp_nmpc_sk(msk); + if (IS_ERR(ssk)) { release_sock(sk); - return PTR_ERR(ssock); + return PTR_ERR(ssk); } - issk = inet_sk(ssock->sk); + issk = inet_sk(ssk); switch (optname) { case IP_FREEBIND: @@ -763,18 +763,18 @@ static int mptcp_setsockopt_first_sf_only(struct mptcp_sock *msk, int level, int sockptr_t optval, unsigned int optlen) { struct sock *sk = (struct sock *)msk; - struct socket *sock; + struct sock *ssk; int ret; /* Limit to first subflow, before the connection establishment */ lock_sock(sk); - sock = __mptcp_nmpc_socket(msk); - if (IS_ERR(sock)) { - ret = PTR_ERR(sock); + ssk = __mptcp_nmpc_sk(msk); + if (IS_ERR(ssk)) { + ret = PTR_ERR(ssk); goto unlock; } - ret = tcp_setsockopt(sock->sk, level, optname, optval, optlen); + ret = tcp_setsockopt(ssk, level, optname, optval, optlen); unlock: release_sock(sk); @@ -864,9 +864,8 @@ static int mptcp_getsockopt_first_sf_only(struct mptcp_sock *msk, int level, int char __user *optval, int __user *optlen) { struct sock *sk = (struct sock *)msk; - struct socket *ssock; - int ret; struct sock *ssk; + int ret; lock_sock(sk); ssk = msk->first; @@ -875,13 +874,13 @@ static int mptcp_getsockopt_first_sf_only(struct mptcp_sock *msk, int level, int goto out; } - ssock = __mptcp_nmpc_socket(msk); - if (IS_ERR(ssock)) { - ret = PTR_ERR(ssock); + ssk = __mptcp_nmpc_sk(msk); + if (IS_ERR(ssk)) { + ret = PTR_ERR(ssk); goto out; } - ret = tcp_getsockopt(ssock->sk, level, optname, optval, optlen); + ret = tcp_getsockopt(ssk, level, optname, optval, optlen); out: release_sock(sk);