diff --git a/include/net/inet_sock.h b/include/net/inet_sock.h index 9e1111f5915b..d81b7f85819e 100644 --- a/include/net/inet_sock.h +++ b/include/net/inet_sock.h @@ -252,6 +252,11 @@ struct inet_sock { #define IP_CMSG_CHECKSUM BIT(7) #define IP_CMSG_RECVFRAGSIZE BIT(8) +static inline bool sk_is_inet(struct sock *sk) +{ + return sk->sk_family == AF_INET || sk->sk_family == AF_INET6; +} + /** * sk_to_full_sk - Access to a full socket * @sk: pointer to a socket diff --git a/net/core/skmsg.c b/net/core/skmsg.c index cc381165ea08..ede0af308f40 100644 --- a/net/core/skmsg.c +++ b/net/core/skmsg.c @@ -695,6 +695,11 @@ struct sk_psock *sk_psock_init(struct sock *sk, int node) write_lock_bh(&sk->sk_callback_lock); + if (sk_is_inet(sk) && inet_csk_has_ulp(sk)) { + psock = ERR_PTR(-EINVAL); + goto out; + } + if (sk->sk_user_data) { psock = ERR_PTR(-EBUSY); goto out; diff --git a/net/ipv4/tcp_bpf.c b/net/ipv4/tcp_bpf.c index 1cdcb4df0eb7..2c597a4e429a 100644 --- a/net/ipv4/tcp_bpf.c +++ b/net/ipv4/tcp_bpf.c @@ -612,9 +612,6 @@ int tcp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore) return 0; } - if (inet_csk_has_ulp(sk)) - return -EINVAL; - if (sk->sk_family == AF_INET6) { if (tcp_bpf_assert_proto_ops(psock->sk_proto)) return -EINVAL; diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index 9aac9c60d786..62b1c5e32bbd 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -790,6 +790,8 @@ static void tls_update(struct sock *sk, struct proto *p, { struct tls_context *ctx; + WARN_ON_ONCE(sk->sk_prot == p); + ctx = tls_get_ctx(sk); if (likely(ctx)) { ctx->sk_write_space = write_space;