2020-03-09 11:12:38 +00:00
// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2020 Cloudflare Ltd https://cloudflare.com */
# include <linux/skmsg.h>
# include <net/sock.h>
# include <net/udp.h>
2021-03-30 19:32:34 -07:00
# include <net/inet_common.h>
# include "udp_impl.h"
/*
 * Base proto that the IPv6 entry of udp_bpf_prots[] was last built from.
 * Published with smp_store_release() under udpv6_prot_lock in
 * udp_bpf_check_v6_needs_rebuild() and used by sk_udp_recvmsg() to reach
 * the original recvmsg handler for AF_INET6 sockets.
 */
static struct proto * udpv6_prot_saved __read_mostly ;
/*
 * Forward a receive request to the unmodified UDP recvmsg handler.
 *
 * The IPv4 handler from udp_prot is used unless IPv6 support is enabled
 * and the socket is AF_INET6, in which case the handler of the saved
 * IPv6 proto is called instead. Returns whatever that handler returns.
 */
static int sk_udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
			  int noblock, int flags, int *addr_len)
{
	const struct proto *base = &udp_prot;

#if IS_ENABLED(CONFIG_IPV6)
	if (sk->sk_family == AF_INET6)
		base = udpv6_prot_saved;
#endif

	return base->recvmsg(sk, msg, len, noblock, flags, addr_len);
}
2021-06-14 19:13:35 -07:00
static bool udp_sk_has_data ( struct sock * sk )
{
return ! skb_queue_empty ( & udp_sk ( sk ) - > reader_queue ) | |
! skb_queue_empty ( & sk - > sk_receive_queue ) ;
}
static bool psock_has_data ( struct sk_psock * psock )
{
return ! skb_queue_empty ( & psock - > ingress_skb ) | |
! sk_psock_queue_empty ( psock ) ;
}
/*
 * Non-zero when either the socket's own queues (udp_sk_has_data) or the
 * psock's queues (psock_has_data) hold data. The psock check is skipped
 * when the socket check already succeeded.
 */
#define udp_msg_has_data(__sk, __psock)	\
	(udp_sk_has_data(__sk) || psock_has_data(__psock))
2021-06-29 15:45:27 -07:00
static int udp_msg_wait_data ( struct sock * sk , struct sk_psock * psock ,
long timeo )
2021-06-14 19:13:35 -07:00
{
DEFINE_WAIT_FUNC ( wait , woken_wake_function ) ;
int ret = 0 ;
if ( sk - > sk_shutdown & RCV_SHUTDOWN )
return 1 ;
if ( ! timeo )
return ret ;
add_wait_queue ( sk_sleep ( sk ) , & wait ) ;
sk_set_bit ( SOCKWQ_ASYNC_WAITDATA , sk ) ;
ret = udp_msg_has_data ( sk , psock ) ;
if ( ! ret ) {
wait_woken ( & wait , TASK_INTERRUPTIBLE , timeo ) ;
ret = udp_msg_has_data ( sk , psock ) ;
}
sk_clear_bit ( SOCKWQ_ASYNC_WAITDATA , sk ) ;
remove_wait_queue ( sk_sleep ( sk ) , & wait ) ;
return ret ;
}
2021-03-30 19:32:34 -07:00
/*
 * recvmsg handler installed by udp_bpf_rebuild_protos().
 *
 * MSG_ERRQUEUE requests go to inet_recv_error(). If the socket has no
 * psock, or the psock has nothing queued, the request is passed to the
 * regular UDP handler via sk_udp_recvmsg(). Otherwise data is pulled
 * from the psock with sk_msg_recvmsg(); when that yields nothing the
 * caller waits (bounded by the socket receive timeout) and then retries
 * the psock, falls back to the regular handler, or gives up with -EAGAIN.
 *
 * Returns the value produced by whichever receive path was taken, or
 * -EAGAIN when the wait ended without data.
 */
static int udp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
			   int nonblock, int flags, int *addr_len)
{
	struct sk_psock *psock;
	int ret;

	if (unlikely(flags & MSG_ERRQUEUE))
		return inet_recv_error(sk, msg, len, addr_len);

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return sk_udp_recvmsg(sk, msg, len, nonblock, flags, addr_len);

	if (!psock_has_data(psock)) {
		ret = sk_udp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
		goto out;
	}

	for (;;) {
		long timeo;

		ret = sk_msg_recvmsg(sk, psock, msg, len, flags);
		if (ret)
			break;

		/* Nothing copied: wait for data on either side. */
		timeo = sock_rcvtimeo(sk, nonblock);
		if (!udp_msg_wait_data(sk, psock, timeo)) {
			ret = -EAGAIN;
			break;
		}

		/* Data is on the socket rather than the psock. */
		if (!psock_has_data(psock)) {
			ret = sk_udp_recvmsg(sk, msg, len, nonblock, flags,
					     addr_len);
			break;
		}
		/* The psock has data now: try it again. */
	}

out:
	/* Drop the reference taken by sk_psock_get(). */
	sk_psock_put(sk, psock);
	return ret;
}
2020-03-09 11:12:38 +00:00
/* Indices into udp_bpf_prots[], one per address family. */
enum {
	UDP_BPF_IPV4 = 0,	/* slot for AF_INET sockets */
	UDP_BPF_IPV6,		/* slot for every other family */
	UDP_BPF_NUM_PROTS,	/* number of slots; keep last */
};
/* Serializes rebuilds of the IPv6 slot in udp_bpf_check_v6_needs_rebuild(). */
static DEFINE_SPINLOCK ( udpv6_prot_lock ) ;
/* BPF-aware copies of the UDP protos, indexed by UDP_BPF_IPV4 / UDP_BPF_IPV6. */
static struct proto udp_bpf_prots [ UDP_BPF_NUM_PROTS ] ;
static void udp_bpf_rebuild_protos ( struct proto * prot , const struct proto * base )
{
* prot = * base ;
prot - > close = sock_map_close ;
2021-03-30 19:32:34 -07:00
prot - > recvmsg = udp_bpf_recvmsg ;
2021-10-08 13:33:05 -07:00
prot - > sock_is_readable = sk_msg_is_readable ;
2020-03-09 11:12:38 +00:00
}
2020-08-21 11:29:43 +01:00
static void udp_bpf_check_v6_needs_rebuild ( struct proto * ops )
2020-03-09 11:12:38 +00:00
{
2020-08-21 11:29:43 +01:00
if ( unlikely ( ops ! = smp_load_acquire ( & udpv6_prot_saved ) ) ) {
2020-03-09 11:12:38 +00:00
spin_lock_bh ( & udpv6_prot_lock ) ;
if ( likely ( ops ! = udpv6_prot_saved ) ) {
udp_bpf_rebuild_protos ( & udp_bpf_prots [ UDP_BPF_IPV6 ] , ops ) ;
smp_store_release ( & udpv6_prot_saved , ops ) ;
}
spin_unlock_bh ( & udpv6_prot_lock ) ;
}
}
/*
 * Init-time helper: build the IPv4 slot of udp_bpf_prots[] from udp_prot.
 * Always returns 0.
 */
static int __init udp_bpf_v4_build_proto(void)
{
	struct proto *v4_prot = &udp_bpf_prots[UDP_BPF_IPV4];

	udp_bpf_rebuild_protos(v4_prot, &udp_prot);

	return 0;
}
2021-07-14 17:47:50 +02:00
/* Register udp_bpf_v4_build_proto() as a late initcall. */
late_initcall ( udp_bpf_v4_build_proto ) ;
2020-03-09 11:12:38 +00:00
2021-04-06 20:21:11 -07:00
/*
 * Switch a UDP socket between its original proto and the BPF-aware one.
 *
 * @sk:      socket to update
 * @psock:   psock holding the saved proto and write_space callback
 * @restore: true to put the saved callbacks back, false to install the
 *           matching entry of udp_bpf_prots[]
 *
 * For AF_INET6 sockets the IPv6 entry is (re)built from the saved proto
 * if needed before it is installed. Always returns 0.
 */
int udp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
{
	int idx;

	if (restore) {
		sk->sk_write_space = psock->saved_write_space;
		WRITE_ONCE(sk->sk_prot, psock->sk_proto);
		return 0;
	}

	if (sk->sk_family == AF_INET) {
		idx = UDP_BPF_IPV4;
	} else {
		idx = UDP_BPF_IPV6;
		if (sk->sk_family == AF_INET6)
			udp_bpf_check_v6_needs_rebuild(psock->sk_proto);
	}

	WRITE_ONCE(sk->sk_prot, &udp_bpf_prots[idx]);
	return 0;
}
2021-03-30 19:32:31 -07:00
/* Export udp_bpf_update_proto() to GPL-compatible modules. */
EXPORT_SYMBOL_GPL ( udp_bpf_update_proto ) ;