diff --git a/net/mptcp/options.c b/net/mptcp/options.c index 5ded85e2c374..b30cea2fbf3f 100644 --- a/net/mptcp/options.c +++ b/net/mptcp/options.c @@ -1594,8 +1594,7 @@ mp_rst: TCPOLEN_MPTCP_PRIO, opts->backup, TCPOPT_NOP); - MPTCP_INC_STATS(sock_net((const struct sock *)tp), - MPTCP_MIB_MPPRIOTX); + MPTCP_INC_STATS(sock_net(ssk), MPTCP_MIB_MPPRIOTX); } mp_capable_done: diff --git a/net/mptcp/pm_netlink.c b/net/mptcp/pm_netlink.c index 2ea7eae43bdb..b5505b8167f9 100644 --- a/net/mptcp/pm_netlink.c +++ b/net/mptcp/pm_netlink.c @@ -1143,7 +1143,7 @@ void mptcp_pm_nl_subflow_chk_stale(const struct mptcp_sock *msk, struct sock *ss if (!tcp_rtx_and_write_queues_empty(ssk)) { subflow->stale = 1; __mptcp_retransmit_pending_data(sk); - MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_SUBFLOWSTALE); + MPTCP_INC_STATS(net, MPTCP_MIB_SUBFLOWSTALE); } unlock_sock_fast(ssk, slow); @@ -1903,8 +1903,7 @@ static int mptcp_nl_cmd_set_flags(struct sk_buff *skb, struct genl_info *info) } if (token) - return mptcp_userspace_pm_set_flags(sock_net(skb->sk), - token, &addr, &remote, bkup); + return mptcp_userspace_pm_set_flags(net, token, &addr, &remote, bkup); spin_lock_bh(&pernet->lock); entry = __lookup_addr(pernet, &addr.addr, lookup_by_id); diff --git a/net/mptcp/protocol.c b/net/mptcp/protocol.c index b7ad030dfe89..13595d6dad8c 100644 --- a/net/mptcp/protocol.c +++ b/net/mptcp/protocol.c @@ -923,9 +923,8 @@ static void mptcp_check_for_eof(struct mptcp_sock *msk) static struct sock *mptcp_subflow_recv_lookup(const struct mptcp_sock *msk) { struct mptcp_subflow_context *subflow; - struct sock *sk = (struct sock *)msk; - sock_owned_by_me(sk); + msk_owned_by_me(msk); mptcp_for_each_subflow(msk, subflow) { if (READ_ONCE(subflow->data_avail)) @@ -1408,7 +1407,7 @@ static struct sock *mptcp_subflow_get_send(struct mptcp_sock *msk) u64 linger_time; long tout = 0; - sock_owned_by_me(sk); + msk_owned_by_me(msk); if (__mptcp_check_fallback(msk)) { if (!msk->first) @@ -1890,7 +1889,7 @@ static void mptcp_rcv_space_adjust(struct mptcp_sock *msk, int copied) u32 time, advmss = 1; u64 rtt_us, mstamp; - sock_owned_by_me(sk); + msk_owned_by_me(msk); if (copied <= 0) return; @@ -2217,7 +2216,7 @@ static struct sock *mptcp_subflow_get_retrans(struct mptcp_sock *msk) struct mptcp_subflow_context *subflow; int min_stale_count = INT_MAX; - sock_owned_by_me((const struct sock *)msk); + msk_owned_by_me(msk); if (__mptcp_check_fallback(msk)) return NULL; @@ -2724,8 +2723,8 @@ static int mptcp_init_sock(struct sock *sk) mptcp_ca_reset(sk); sk_sockets_allocated_inc(sk); - sk->sk_rcvbuf = READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_rmem[1]); - sk->sk_sndbuf = READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_wmem[1]); + sk->sk_rcvbuf = READ_ONCE(net->ipv4.sysctl_tcp_rmem[1]); + sk->sk_sndbuf = READ_ONCE(net->ipv4.sysctl_tcp_wmem[1]); return 0; } @@ -2892,6 +2891,12 @@ static __poll_t mptcp_check_readable(struct mptcp_sock *msk) return EPOLLIN | EPOLLRDNORM; } +static void mptcp_listen_inuse_dec(struct sock *sk) +{ + if (inet_sk_state_load(sk) == TCP_LISTEN) + sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1); +} + bool __mptcp_close(struct sock *sk, long timeout) { struct mptcp_subflow_context *subflow; @@ -2901,6 +2906,7 @@ bool __mptcp_close(struct sock *sk, long timeout) sk->sk_shutdown = SHUTDOWN_MASK; if ((1 << sk->sk_state) & (TCPF_LISTEN | TCPF_CLOSE)) { + mptcp_listen_inuse_dec(sk); inet_sk_state_store(sk, TCP_CLOSE); goto cleanup; } @@ -3001,6 +3007,7 @@ static int mptcp_disconnect(struct sock *sk, int flags) if (msk->fastopening) return 0; + mptcp_listen_inuse_dec(sk); inet_sk_state_store(sk, TCP_CLOSE); mptcp_stop_timer(sk); @@ -3639,12 +3646,13 @@ static int mptcp_stream_connect(struct socket *sock, struct sockaddr *uaddr, static int mptcp_listen(struct socket *sock, int backlog) { struct mptcp_sock *msk = mptcp_sk(sock->sk); + struct sock *sk = sock->sk; struct socket *ssock; int err; pr_debug("msk=%p", msk); - lock_sock(sock->sk); + lock_sock(sk); ssock = __mptcp_nmpc_socket(msk); if (!ssock) { err = -EINVAL; @@ -3652,18 +3660,20 @@ static int mptcp_listen(struct socket *sock, int backlog) } mptcp_token_destroy(msk); - inet_sk_state_store(sock->sk, TCP_LISTEN); - sock_set_flag(sock->sk, SOCK_RCU_FREE); + inet_sk_state_store(sk, TCP_LISTEN); + sock_set_flag(sk, SOCK_RCU_FREE); err = ssock->ops->listen(ssock, backlog); - inet_sk_state_store(sock->sk, inet_sk_state_load(ssock->sk)); - if (!err) - mptcp_copy_inaddrs(sock->sk, ssock->sk); + inet_sk_state_store(sk, inet_sk_state_load(ssock->sk)); + if (!err) { + sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1); + mptcp_copy_inaddrs(sk, ssock->sk); + } mptcp_event_pm_listener(ssock->sk, MPTCP_EVENT_LISTENER_CREATED); unlock: - release_sock(sock->sk); + release_sock(sk); return err; } diff --git a/net/mptcp/protocol.h b/net/mptcp/protocol.h index a0d1658ce59e..5a8af5657796 100644 --- a/net/mptcp/protocol.h +++ b/net/mptcp/protocol.h @@ -754,7 +754,7 @@ static inline void mptcp_token_init_request(struct request_sock *req) int mptcp_token_new_request(struct request_sock *req); void mptcp_token_destroy_request(struct request_sock *req); -int mptcp_token_new_connect(struct sock *sk); +int mptcp_token_new_connect(struct sock *ssk); void mptcp_token_accept(struct mptcp_subflow_request_sock *r, struct mptcp_sock *msk); bool mptcp_token_exists(u32 token); diff --git a/net/mptcp/sockopt.c b/net/mptcp/sockopt.c index d4b1e6ec1b36..582ed93bcc8a 100644 --- a/net/mptcp/sockopt.c +++ b/net/mptcp/sockopt.c @@ -18,7 +18,7 @@ static struct sock *__mptcp_tcp_fallback(struct mptcp_sock *msk) { - sock_owned_by_me((const struct sock *)msk); + msk_owned_by_me(msk); if (likely(!__mptcp_check_fallback(msk))) return NULL; diff --git a/net/mptcp/token.c b/net/mptcp/token.c index 65430f314a68..5bb924534387 100644 --- a/net/mptcp/token.c +++ b/net/mptcp/token.c @@ -134,7 +134,7 @@ int mptcp_token_new_request(struct request_sock *req) /** * mptcp_token_new_connect - create new key/idsn/token for subflow - * @sk: the socket that will initiate a connection + * @ssk: the socket that will initiate a connection * * This function is called when a new outgoing mptcp connection is * initiated. @@ -148,11 +148,12 @@ int mptcp_token_new_request(struct request_sock *req) * * returns 0 on success. */ -int mptcp_token_new_connect(struct sock *sk) +int mptcp_token_new_connect(struct sock *ssk) { - struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk); + struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(ssk); struct mptcp_sock *msk = mptcp_sk(subflow->conn); int retries = MPTCP_TOKEN_MAX_RETRIES; + struct sock *sk = subflow->conn; struct token_bucket *bucket; again: @@ -169,12 +170,13 @@ again: } pr_debug("ssk=%p, local_key=%llu, token=%u, idsn=%llu\n", - sk, subflow->local_key, subflow->token, subflow->idsn); + ssk, subflow->local_key, subflow->token, subflow->idsn); WRITE_ONCE(msk->token, subflow->token); __sk_nulls_add_node_rcu((struct sock *)msk, &bucket->msk_chain); bucket->chain_len++; spin_unlock_bh(&bucket->lock); + sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1); return 0; } @@ -190,8 +192,10 @@ void mptcp_token_accept(struct mptcp_subflow_request_sock *req, struct mptcp_sock *msk) { struct mptcp_subflow_request_sock *pos; + struct sock *sk = (struct sock *)msk; struct token_bucket *bucket; + sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1); bucket = token_bucket(req->token); spin_lock_bh(&bucket->lock); @@ -370,12 +374,14 @@ void mptcp_token_destroy_request(struct request_sock *req) */ void mptcp_token_destroy(struct mptcp_sock *msk) { + struct sock *sk = (struct sock *)msk; struct token_bucket *bucket; struct mptcp_sock *pos; if (sk_unhashed((struct sock *)msk)) return; + sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1); bucket = token_bucket(msk->token); spin_lock_bh(&bucket->lock); pos = __token_lookup_msk(bucket, msk->token); diff --git a/net/mptcp/token_test.c b/net/mptcp/token_test.c index 5d984bec1cd8..0758865ab658 100644 --- a/net/mptcp/token_test.c +++ b/net/mptcp/token_test.c @@ -57,6 +57,9 @@ static struct mptcp_sock *build_msk(struct kunit *test) KUNIT_EXPECT_NOT_ERR_OR_NULL(test, msk); refcount_set(&((struct sock *)msk)->sk_refcnt, 1); sock_net_set((struct sock *)msk, &init_net); + + /* be sure the token helpers can dereference sk->sk_prot */ + ((struct sock *)msk)->sk_prot = &tcp_prot; return msk; } diff --git a/tools/testing/selftests/net/mptcp/diag.sh b/tools/testing/selftests/net/mptcp/diag.sh index 24bcd7b9bdb2..ef628b16fe9b 100755 --- a/tools/testing/selftests/net/mptcp/diag.sh +++ b/tools/testing/selftests/net/mptcp/diag.sh @@ -17,6 +17,11 @@ flush_pids() sleep 1.1 ip netns pids "${ns}" | xargs --no-run-if-empty kill -SIGUSR1 &>/dev/null + + for _ in $(seq 10); do + [ -z "$(ip netns pids "${ns}")" ] && break + sleep 0.1 + done } cleanup() @@ -37,15 +42,20 @@ if [ $? -ne 0 ];then exit $ksft_skip fi +get_msk_inuse() +{ + ip netns exec $ns cat /proc/net/protocols | awk '$1~/^MPTCP$/{print $3}' +} + __chk_nr() { - local condition="$1" + local command="$1" local expected=$2 local msg nr shift 2 msg=$* - nr=$(ss -inmHMN $ns | $condition) + nr=$(eval $command) printf "%-50s" "$msg" if [ $nr != $expected ]; then @@ -57,9 +67,17 @@ __chk_nr() test_cnt=$((test_cnt+1)) } +__chk_msk_nr() +{ + local condition=$1 + shift 1 + + __chk_nr "ss -inmHMN $ns | $condition" $* +} + chk_msk_nr() { - __chk_nr "grep -c token:" $* + __chk_msk_nr "grep -c token:" $* } wait_msk_nr() @@ -97,12 +115,12 @@ wait_msk_nr() chk_msk_fallback_nr() { - __chk_nr "grep -c fallback" $* + __chk_msk_nr "grep -c fallback" $* } chk_msk_remote_key_nr() { - __chk_nr "grep -c remote_key" $* + __chk_msk_nr "grep -c remote_key" $* } __chk_listen() @@ -142,6 +160,26 @@ chk_msk_listen() nr=$(ss -Ml $filter | wc -l) } +chk_msk_inuse() +{ + local expected=$1 + local listen_nr + + shift 1 + + listen_nr=$(ss -N "${ns}" -Ml | grep -c LISTEN) + expected=$((expected + listen_nr)) + + for _ in $(seq 10); do + if [ $(get_msk_inuse) -eq $expected ];then + break + fi + sleep 0.1 + done + + __chk_nr get_msk_inuse $expected $* +} + # $1: ns, $2: port wait_local_port_listen() { @@ -195,8 +233,10 @@ wait_connected $ns 10000 chk_msk_nr 2 "after MPC handshake " chk_msk_remote_key_nr 2 "....chk remote_key" chk_msk_fallback_nr 0 "....chk no fallback" +chk_msk_inuse 2 "....chk 2 msk in use" flush_pids +chk_msk_inuse 0 "....chk 0 msk in use after flush" echo "a" | \ timeout ${timeout_test} \ @@ -211,8 +251,11 @@ echo "b" | \ 127.0.0.1 >/dev/null & wait_connected $ns 10001 chk_msk_fallback_nr 1 "check fallback" +chk_msk_inuse 1 "....chk 1 msk in use" flush_pids +chk_msk_inuse 0 "....chk 0 msk in use after flush" + NR_CLIENTS=100 for I in `seq 1 $NR_CLIENTS`; do echo "a" | \ @@ -232,6 +275,9 @@ for I in `seq 1 $NR_CLIENTS`; do done wait_msk_nr $((NR_CLIENTS*2)) "many msk socket present" +chk_msk_inuse $((NR_CLIENTS*2)) "....chk many msk in use" flush_pids +chk_msk_inuse 0 "....chk 0 msk in use after flush" + exit $ret diff --git a/tools/testing/selftests/net/mptcp/mptcp_connect.c b/tools/testing/selftests/net/mptcp/mptcp_connect.c index 8a8266957bc5..b25a31445ded 100644 --- a/tools/testing/selftests/net/mptcp/mptcp_connect.c +++ b/tools/testing/selftests/net/mptcp/mptcp_connect.c @@ -627,7 +627,7 @@ static int copyfd_io_poll(int infd, int peerfd, int outfd, char rbuf[8192]; ssize_t len; - if (fds.events == 0) + if (fds.events == 0 || quit) break; switch (poll(&fds, 1, poll_timeout)) { @@ -733,7 +733,7 @@ static int copyfd_io_poll(int infd, int peerfd, int outfd, } /* leave some time for late join/announce */ - if (cfg_remove) + if (cfg_remove && !quit) usleep(cfg_wait); return 0;