net: sk_msg: Simplify sk_psock initialization
Initializing psock->sk_proto and other saved callbacks is only done
in sk_psock_update_proto, after sk_psock_init has returned. The logic
for this is difficult to follow, and needlessly complex. Instead,
initialize psock->sk_proto whenever we allocate a new psock.
Additionally, assert the following invariants:

* The SK has no ULP: ULP does its own finagling of sk->sk_prot
* sk_user_data is unused: we need it to store sk_psock

Protect our access to sk_user_data with sk_callback_lock, which is
what other users like reuseport arrays, etc. do.

The result is that an sk_psock is always fully initialized, and that
psock->sk_proto is always the "original" struct proto. The latter
allows us to use psock->sk_proto when initializing IPv6 TCP / UDP
callbacks for sockmap.

Signed-off-by: Lorenz Bauer <lmb@cloudflare.com>
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
Acked-by: John Fastabend <john.fastabend@gmail.com>
Link: https://lore.kernel.org/bpf/20200821102948.21918-2-lmb@cloudflare.com
parent dca5612f8e
commit 7b219da43f
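Since sk_psock_init() can now fail for three distinct reasons, it reports
them via ERR_PTR() instead of returning NULL. A minimal sketch of the
updated calling convention, using a hypothetical wrapper function that is
not part of the patch:

#include <linux/err.h>
#include <linux/skmsg.h>

/* Hypothetical caller: attach a psock to a socket. After this patch the
 * precise cause is propagated (-EINVAL if a ULP is installed, -EBUSY if
 * sk_user_data is taken, -ENOMEM on allocation failure) instead of a
 * bare NULL that callers had to map to -ENOMEM themselves.
 */
static int example_attach_psock(struct sock *sk, int numa_node)
{
	struct sk_psock *psock = sk_psock_init(sk, numa_node);

	if (IS_ERR(psock))	/* sk_psock_init() never returns NULL now */
		return PTR_ERR(psock);

	/* psock->sk_proto and the saved callbacks are already valid here. */
	return 0;
}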
--- a/include/linux/skmsg.h
+++ b/include/linux/skmsg.h
@@ -340,23 +340,6 @@ static inline void sk_psock_update_proto(struct sock *sk,
 					 struct sk_psock *psock,
 					 struct proto *ops)
 {
-	/* Initialize saved callbacks and original proto only once, since this
-	 * function may be called multiple times for a psock, e.g. when
-	 * psock->progs.msg_parser is updated.
-	 *
-	 * Since we've not installed the new proto, psock is not yet in use and
-	 * we can initialize it without synchronization.
-	 */
-	if (!psock->sk_proto) {
-		struct proto *orig = READ_ONCE(sk->sk_prot);
-
-		psock->saved_unhash = orig->unhash;
-		psock->saved_close = orig->close;
-		psock->saved_write_space = sk->sk_write_space;
-
-		psock->sk_proto = orig;
-	}
-
 	/* Pairs with lockless read in sk_clone_lock() */
 	WRITE_ONCE(sk->sk_prot, ops);
 }
--- a/net/core/skmsg.c
+++ b/net/core/skmsg.c
@@ -494,14 +494,34 @@ end:
 
 struct sk_psock *sk_psock_init(struct sock *sk, int node)
 {
-	struct sk_psock *psock = kzalloc_node(sizeof(*psock),
-					      GFP_ATOMIC | __GFP_NOWARN,
-					      node);
-	if (!psock)
-		return NULL;
+	struct sk_psock *psock;
+	struct proto *prot;
+
+	write_lock_bh(&sk->sk_callback_lock);
+
+	if (inet_csk_has_ulp(sk)) {
+		psock = ERR_PTR(-EINVAL);
+		goto out;
+	}
+
+	if (sk->sk_user_data) {
+		psock = ERR_PTR(-EBUSY);
+		goto out;
+	}
+
+	psock = kzalloc_node(sizeof(*psock), GFP_ATOMIC | __GFP_NOWARN, node);
+	if (!psock) {
+		psock = ERR_PTR(-ENOMEM);
+		goto out;
+	}
 
+	prot = READ_ONCE(sk->sk_prot);
 	psock->sk = sk;
 	psock->eval = __SK_NONE;
+	psock->sk_proto = prot;
+	psock->saved_unhash = prot->unhash;
+	psock->saved_close = prot->close;
+	psock->saved_write_space = sk->sk_write_space;
 
 	INIT_LIST_HEAD(&psock->link);
 	spin_lock_init(&psock->link_lock);
@@ -516,6 +536,8 @@ struct sk_psock *sk_psock_init(struct sock *sk, int node)
 	rcu_assign_sk_user_data_nocopy(sk, psock);
 	sock_hold(sk);
 
+out:
+	write_unlock_bh(&sk->sk_callback_lock);
 	return psock;
 }
 EXPORT_SYMBOL_GPL(sk_psock_init);
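For context (this reader side is unchanged by the patch): the psock
pointer published above via rcu_assign_sk_user_data_nocopy() is read
locklessly under RCU, which is why the writer takes sk_callback_lock
before checking and publishing sk_user_data. A sketch of a reader,
assuming the existing sk_psock() accessor from include/linux/skmsg.h:

#include <linux/rcupdate.h>
#include <linux/skmsg.h>

/* Hypothetical reader: sk_psock() is rcu_dereference_sk_user_data()
 * underneath, so it must run in an RCU read-side critical section.
 */
static bool example_sock_has_psock(struct sock *sk)
{
	struct sk_psock *psock;
	bool ret;

	rcu_read_lock();
	psock = sk_psock(sk);
	ret = psock != NULL;
	rcu_read_unlock();

	return ret;
}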
--- a/net/core/sock_map.c
+++ b/net/core/sock_map.c
@@ -184,8 +184,6 @@ static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock)
 {
 	struct proto *prot;
 
-	sock_owned_by_me(sk);
-
 	switch (sk->sk_type) {
 	case SOCK_STREAM:
 		prot = tcp_bpf_get_proto(sk, psock);
@@ -272,8 +270,8 @@ static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
 		}
 	} else {
 		psock = sk_psock_init(sk, map->numa_node);
-		if (!psock) {
-			ret = -ENOMEM;
+		if (IS_ERR(psock)) {
+			ret = PTR_ERR(psock);
 			goto out_progs;
 		}
 	}
@@ -322,8 +320,8 @@ static int sock_map_link_no_progs(struct bpf_map *map, struct sock *sk)
 
 	if (!psock) {
 		psock = sk_psock_init(sk, map->numa_node);
-		if (!psock)
-			return -ENOMEM;
+		if (IS_ERR(psock))
+			return PTR_ERR(psock);
 	}
 
 	ret = sock_map_init_proto(sk, psock);
@@ -478,8 +476,6 @@ static int sock_map_update_common(struct bpf_map *map, u32 idx,
 		return -EINVAL;
 	if (unlikely(idx >= map->max_entries))
 		return -E2BIG;
-	if (inet_csk_has_ulp(sk))
-		return -EINVAL;
 
 	link = sk_psock_init_link();
 	if (!link)
@@ -855,8 +851,6 @@ static int sock_hash_update_common(struct bpf_map *map, void *key,
 	WARN_ON_ONCE(!rcu_read_lock_held());
 	if (unlikely(flags > BPF_EXIST))
 		return -EINVAL;
-	if (inet_csk_has_ulp(sk))
-		return -EINVAL;
 
 	link = sk_psock_init_link();
 	if (!link)
--- a/net/ipv4/tcp_bpf.c
+++ b/net/ipv4/tcp_bpf.c
@@ -567,10 +567,9 @@ static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
 	prot[TCP_BPF_TX].sendpage = tcp_bpf_sendpage;
 }
 
-static void tcp_bpf_check_v6_needs_rebuild(struct sock *sk, struct proto *ops)
+static void tcp_bpf_check_v6_needs_rebuild(struct proto *ops)
 {
-	if (sk->sk_family == AF_INET6 &&
-	    unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) {
+	if (unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) {
 		spin_lock_bh(&tcpv6_prot_lock);
 		if (likely(ops != tcpv6_prot_saved)) {
 			tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV6], ops);
@@ -603,13 +602,11 @@ struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
 	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
 	int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;
 
-	if (!psock->sk_proto) {
-		struct proto *ops = READ_ONCE(sk->sk_prot);
-
-		if (tcp_bpf_assert_proto_ops(ops))
+	if (sk->sk_family == AF_INET6) {
+		if (tcp_bpf_assert_proto_ops(psock->sk_proto))
 			return ERR_PTR(-EINVAL);
 
-		tcp_bpf_check_v6_needs_rebuild(sk, ops);
+		tcp_bpf_check_v6_needs_rebuild(psock->sk_proto);
 	}
 
 	return &tcp_bpf_prots[family][config];
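tcp_bpf_check_v6_needs_rebuild() keeps its double-checked update: a
lockless smp_load_acquire() fast path, re-checked under a spinlock so
concurrent callers rebuild the IPv6 prot table only once. A generic
sketch of that pattern with hypothetical names (the real code guards
tcpv6_prot_saved with tcpv6_prot_lock):

#include <linux/spinlock.h>
#include <net/sock.h>

static struct proto *example_saved;	/* hypothetical saved pointer */
static DEFINE_SPINLOCK(example_lock);	/* hypothetical lock */

static void example_check_needs_rebuild(struct proto *ops)
{
	/* Fast path: acquire-load pairs with the store_release below. */
	if (unlikely(ops != smp_load_acquire(&example_saved))) {
		spin_lock_bh(&example_lock);
		/* Re-check under the lock: another CPU may have won. */
		if (likely(ops != example_saved)) {
			/* ... rebuild derived prot tables from @ops ... */
			smp_store_release(&example_saved, ops);
		}
		spin_unlock_bh(&example_lock);
	}
}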
--- a/net/ipv4/udp_bpf.c
+++ b/net/ipv4/udp_bpf.c
@@ -22,10 +22,9 @@ static void udp_bpf_rebuild_protos(struct proto *prot, const struct proto *base)
 	prot->close = sock_map_close;
 }
 
-static void udp_bpf_check_v6_needs_rebuild(struct sock *sk, struct proto *ops)
+static void udp_bpf_check_v6_needs_rebuild(struct proto *ops)
 {
-	if (sk->sk_family == AF_INET6 &&
-	    unlikely(ops != smp_load_acquire(&udpv6_prot_saved))) {
+	if (unlikely(ops != smp_load_acquire(&udpv6_prot_saved))) {
 		spin_lock_bh(&udpv6_prot_lock);
 		if (likely(ops != udpv6_prot_saved)) {
 			udp_bpf_rebuild_protos(&udp_bpf_prots[UDP_BPF_IPV6], ops);
@@ -46,8 +45,8 @@ struct proto *udp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
 {
 	int family = sk->sk_family == AF_INET ? UDP_BPF_IPV4 : UDP_BPF_IPV6;
 
-	if (!psock->sk_proto)
-		udp_bpf_check_v6_needs_rebuild(sk, READ_ONCE(sk->sk_prot));
+	if (sk->sk_family == AF_INET6)
+		udp_bpf_check_v6_needs_rebuild(psock->sk_proto);
 
 	return &udp_bpf_prots[family];
 }