// SPDX-License-Identifier: GPL-2.0-only
/*
 *  HID-BPF support for Linux
 *
 *  Copyright (c) 2022 Benjamin Tissoires
 */

#include <linux/bitops.h>
#include <linux/btf.h>
#include <linux/btf_ids.h>
#include <linux/circ_buf.h>
#include <linux/filter.h>
#include <linux/hid.h>
#include <linux/hid_bpf.h>
#include <linux/init.h>
#include <linux/module.h>
#include <linux/workqueue.h>
#include "hid_bpf_dispatch.h"
#include "entrypoints/entrypoints.lskel.h"

#define HID_BPF_MAX_PROGS 1024 /* keep this in sync with preloaded bpf,
				* needs to be a power of 2 as we use it as
				* a circular buffer
				*/

#define NEXT(idx) (((idx) + 1) & (HID_BPF_MAX_PROGS - 1))
#define PREV(idx) (((idx) - 1) & (HID_BPF_MAX_PROGS - 1))

/*
 * represents one attached program stored in the hid jump table
 */
struct hid_bpf_prog_entry {
	struct bpf_prog *prog;
	struct hid_device *hdev;
	enum hid_bpf_prog_type type;
	u16 idx;
};

struct hid_bpf_jmp_table {
	struct bpf_map *map;
	struct hid_bpf_prog_entry entries[HID_BPF_MAX_PROGS]; /* compacted list, circular buffer */
	int tail, head;
	struct bpf_prog *progs[HID_BPF_MAX_PROGS]; /* idx -> progs mapping */
	unsigned long enabled[BITS_TO_LONGS(HID_BPF_MAX_PROGS)];
};

#define FOR_ENTRIES(__i, __start, __end) \
	for (__i = __start; CIRC_CNT(__end, __i, HID_BPF_MAX_PROGS); __i = NEXT(__i))

static struct hid_bpf_jmp_table jmp_table;

static DEFINE_MUTEX(hid_bpf_attach_lock);	/* held when attaching/detaching programs */

static void hid_bpf_release_progs(struct work_struct *work);

static DECLARE_WORK(release_work, hid_bpf_release_progs);

BTF_ID_LIST(hid_bpf_btf_ids)
BTF_ID(func, hid_bpf_device_event)		/* HID_BPF_PROG_TYPE_DEVICE_EVENT */
BTF_ID(func, hid_bpf_rdesc_fixup)		/* HID_BPF_PROG_TYPE_RDESC_FIXUP */

static int hid_bpf_max_programs(enum hid_bpf_prog_type type)
{
	switch (type) {
	case HID_BPF_PROG_TYPE_DEVICE_EVENT:
		return HID_BPF_MAX_PROGS_PER_DEV;
	case HID_BPF_PROG_TYPE_RDESC_FIXUP:
		return 1;
	default:
		return -EINVAL;
	}
}
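
/*
 * count the programs attached in the jump table that match the given
 * filters: any of @hdev, @prog or @type may be NULL/HID_BPF_PROG_TYPE_UNDEF
 * to mean "match all"
 */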
static int hid_bpf_program_count(struct hid_device *hdev,
				 struct bpf_prog *prog,
				 enum hid_bpf_prog_type type)
{
	int i, n = 0;

	if (type >= HID_BPF_PROG_TYPE_MAX)
		return -EINVAL;

	FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) {
		struct hid_bpf_prog_entry *entry = &jmp_table.entries[i];

		if (type != HID_BPF_PROG_TYPE_UNDEF && entry->type != type)
			continue;

		if (hdev && entry->hdev != hdev)
			continue;

		if (prog && entry->prog != prog)
			continue;

		n++;
	}

	return n;
}
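
/*
 * default, __weak hook: the pre-loaded hid_tail_call entrypoint (attached in
 * hid_bpf_preload_skel()) is expected to hook this function and tail-call
 * into the hid_jmp_table prog_array using ctx->index
 */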
__weak noinline int __hid_bpf_tail_call(struct hid_bpf_ctx *ctx)
{
	return 0;
}
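
/*
 * run all enabled programs of @type attached to @hdev under RCU, in attach
 * order: a negative return value aborts the chain, a positive one is stored
 * in ctx.retval
 */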
int hid_bpf_prog_run(struct hid_device *hdev, enum hid_bpf_prog_type type,
		     struct hid_bpf_ctx_kern *ctx_kern)
{
	struct hid_bpf_prog_list *prog_list;
	int i, idx, err = 0;

	rcu_read_lock();
	prog_list = rcu_dereference(hdev->bpf.progs[type]);

	if (!prog_list)
		goto out_unlock;

	for (i = 0; i < prog_list->prog_cnt; i++) {
		idx = prog_list->prog_idx[i];

		if (!test_bit(idx, jmp_table.enabled))
			continue;

		ctx_kern->ctx.index = idx;
		err = __hid_bpf_tail_call(&ctx_kern->ctx);
		if (err < 0)
			break;
		if (err)
			ctx_kern->ctx.retval = err;
	}

 out_unlock:
	rcu_read_unlock();

	return err;
}

/*
 * assign the list of programs attached to a given hid device.
 */
static void __hid_bpf_set_hdev_progs(struct hid_device *hdev, struct hid_bpf_prog_list *new_list,
				     enum hid_bpf_prog_type type)
{
	struct hid_bpf_prog_list *old_list;

	spin_lock(&hdev->bpf.progs_lock);
	old_list = rcu_dereference_protected(hdev->bpf.progs[type],
					     lockdep_is_held(&hdev->bpf.progs_lock));
	rcu_assign_pointer(hdev->bpf.progs[type], new_list);
	spin_unlock(&hdev->bpf.progs_lock);
	synchronize_rcu();

	kfree(old_list);
}

/*
 * allocate and populate the list of programs attached to a given hid device.
 *
 * Must be called under lock.
 */
static int hid_bpf_populate_hdev(struct hid_device *hdev, enum hid_bpf_prog_type type)
{
	struct hid_bpf_prog_list *new_list;
	int i;

	if (type >= HID_BPF_PROG_TYPE_MAX || !hdev)
		return -EINVAL;

	if (hdev->bpf.destroyed)
		return 0;

	new_list = kzalloc(sizeof(*new_list), GFP_KERNEL);
	if (!new_list)
		return -ENOMEM;

	FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) {
		struct hid_bpf_prog_entry *entry = &jmp_table.entries[i];

		if (entry->type == type && entry->hdev == hdev &&
		    test_bit(entry->idx, jmp_table.enabled))
			new_list->prog_idx[new_list->prog_cnt++] = entry->idx;
	}

	__hid_bpf_set_hdev_progs(hdev, new_list, type);

	return 0;
}
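
/*
 * remove the program at @idx from the BPF prog_array map and from our
 * idx -> progs mapping
 */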
static void __hid_bpf_do_release_prog(int map_fd, unsigned int idx)
{
	skel_map_delete_elem(map_fd, &idx);
	jmp_table.progs[idx] = NULL;
}
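
/*
 * work callback scheduled when a program gets disabled: detach it from the
 * HID devices still referencing it, drop it from the BPF jump table and
 * compact the circular list of entries
 */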
static void hid_bpf_release_progs(struct work_struct *work)
{
	int i, j, n, map_fd = -1;

	if (!jmp_table.map)
		return;

	/* retrieve a fd of our prog_array map in BPF */
	map_fd = skel_map_get_fd_by_id(jmp_table.map->id);
	if (map_fd < 0)
		return;

	mutex_lock(&hid_bpf_attach_lock); /* protects against attaching new programs */

	/* detach unused progs from HID devices */
	FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) {
		struct hid_bpf_prog_entry *entry = &jmp_table.entries[i];
		enum hid_bpf_prog_type type;
		struct hid_device *hdev;

		if (test_bit(entry->idx, jmp_table.enabled))
			continue;

		/* we have an attached prog */
		if (entry->hdev) {
			hdev = entry->hdev;
			type = entry->type;

			hid_bpf_populate_hdev(hdev, type);

			/* mark all other disabled progs from hdev of the given type as detached */
			FOR_ENTRIES(j, i, jmp_table.head) {
				struct hid_bpf_prog_entry *next;

				next = &jmp_table.entries[j];

				if (test_bit(next->idx, jmp_table.enabled))
					continue;

				if (next->hdev == hdev && next->type == type)
					next->hdev = NULL;
			}

			/* if type was rdesc fixup, reconnect device */
			if (type == HID_BPF_PROG_TYPE_RDESC_FIXUP)
				hid_bpf_reconnect(hdev);
		}
	}

	/* remove all unused progs from the jump table */
	FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) {
		struct hid_bpf_prog_entry *entry = &jmp_table.entries[i];

		if (test_bit(entry->idx, jmp_table.enabled))
			continue;

		if (entry->prog)
			__hid_bpf_do_release_prog(map_fd, entry->idx);
	}

	/* compact the entry list */
	n = jmp_table.tail;
	FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) {
		struct hid_bpf_prog_entry *entry = &jmp_table.entries[i];

		if (!test_bit(entry->idx, jmp_table.enabled))
			continue;

		jmp_table.entries[n] = jmp_table.entries[i];
		n = NEXT(n);
	}

	jmp_table.head = n;

	mutex_unlock(&hid_bpf_attach_lock);

	if (map_fd >= 0)
		close_fd(map_fd);
}
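
/* release a single program from the jump table at the given index */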
static void hid_bpf_release_prog_at(int idx)
{
	int map_fd = -1;

	/* retrieve a fd of our prog_array map in BPF */
	map_fd = skel_map_get_fd_by_id(jmp_table.map->id);
	if (map_fd < 0)
		return;

	__hid_bpf_do_release_prog(map_fd, idx);

	close_fd(map_fd);
}

/*
 * Insert the given BPF program represented by its fd in the jmp table.
 * Returns the index in the jump table or a negative error.
 */
static int hid_bpf_insert_prog(int prog_fd, struct bpf_prog *prog)
{
	int i, index = -1, map_fd = -1, err = -EINVAL;

	/* retrieve a fd of our prog_array map in BPF */
	map_fd = skel_map_get_fd_by_id(jmp_table.map->id);
	if (map_fd < 0) {
		err = -EINVAL;
		goto out;
	}

	/* find the first available index in the jmp_table */
	for (i = 0; i < HID_BPF_MAX_PROGS; i++) {
		if (!jmp_table.progs[i] && index < 0) {
			/* mark the index as used */
			jmp_table.progs[i] = prog;
			index = i;
			__set_bit(i, jmp_table.enabled);
		}
	}
	if (index < 0) {
		err = -ENOMEM;
		goto out;
	}

	/* insert the program in the jump table */
	err = skel_map_update_elem(map_fd, &index, &prog_fd, 0);
	if (err)
		goto out;

	/* return the index */
	err = index;

 out:
	if (err < 0)
		__hid_bpf_do_release_prog(map_fd, index);
	if (map_fd >= 0)
		close_fd(map_fd);

	return err;
}
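
/*
 * map a program fd to its HID-BPF program type by matching the program's
 * attach_btf_id against hid_bpf_btf_ids[]; returns HID_BPF_PROG_TYPE_UNDEF
 * when no type matches
 */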
int hid_bpf_get_prog_attach_type(int prog_fd)
{
	struct bpf_prog *prog = NULL;
	int i;
	int prog_type = HID_BPF_PROG_TYPE_UNDEF;

	prog = bpf_prog_get(prog_fd);
	if (IS_ERR(prog))
		return PTR_ERR(prog);

	for (i = 0; i < HID_BPF_PROG_TYPE_MAX; i++) {
		if (hid_bpf_btf_ids[i] == prog->aux->attach_btf_id) {
			prog_type = i;
			break;
		}
	}

	bpf_prog_put(prog);

	return prog_type;
}
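
/*
 * bpf_link release callback: mark the program as disabled in the jump table
 * and defer the actual detach/cleanup to the release workqueue
 */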
static void hid_bpf_link_release(struct bpf_link *link)
{
	struct hid_bpf_link *hid_link =
		container_of(link, struct hid_bpf_link, link);

	__clear_bit(hid_link->hid_table_index, jmp_table.enabled);
	schedule_work(&release_work);
}

static void hid_bpf_link_dealloc(struct bpf_link *link)
{
	struct hid_bpf_link *hid_link =
		container_of(link, struct hid_bpf_link, link);

	kfree(hid_link);
}

static void hid_bpf_link_show_fdinfo(const struct bpf_link *link,
				     struct seq_file *seq)
{
	seq_printf(seq,
		   "attach_type:\tHID-BPF\n");
}

static const struct bpf_link_ops hid_bpf_link_lops = {
	.release = hid_bpf_link_release,
	.dealloc = hid_bpf_link_dealloc,
	.show_fdinfo = hid_bpf_link_show_fdinfo,
};

/* called from syscall */
noinline int
__hid_bpf_attach_prog(struct hid_device *hdev, enum hid_bpf_prog_type prog_type,
		      int prog_fd, __u32 flags)
{
	struct bpf_link_primer link_primer;
	struct hid_bpf_link *link;
	struct bpf_prog *prog = NULL;
	struct hid_bpf_prog_entry *prog_entry;
	int cnt, err = -EINVAL, prog_table_idx = -1;

	/* take a ref on the prog itself */
	prog = bpf_prog_get(prog_fd);
	if (IS_ERR(prog))
		return PTR_ERR(prog);

	mutex_lock(&hid_bpf_attach_lock);

	link = kzalloc(sizeof(*link), GFP_USER);
	if (!link) {
		err = -ENOMEM;
		goto err_unlock;
	}

	bpf_link_init(&link->link, BPF_LINK_TYPE_UNSPEC,
		      &hid_bpf_link_lops, prog);

	/* do not attach too many programs to a given HID device */
	cnt = hid_bpf_program_count(hdev, NULL, prog_type);
	if (cnt < 0) {
		err = cnt;
		goto err_unlock;
	}

	if (cnt >= hid_bpf_max_programs(prog_type)) {
		err = -E2BIG;
		goto err_unlock;
	}

	prog_table_idx = hid_bpf_insert_prog(prog_fd, prog);
	/* if the jmp table is full, abort */
	if (prog_table_idx < 0) {
		err = prog_table_idx;
		goto err_unlock;
	}

	if (flags & HID_BPF_FLAG_INSERT_HEAD) {
		/* take the previous prog_entry slot */
		jmp_table.tail = PREV(jmp_table.tail);
		prog_entry = &jmp_table.entries[jmp_table.tail];
	} else {
		/* take the next prog_entry slot */
		prog_entry = &jmp_table.entries[jmp_table.head];
		jmp_table.head = NEXT(jmp_table.head);
	}

	/* we steal the ref here */
	prog_entry->prog = prog;
	prog_entry->idx = prog_table_idx;
	prog_entry->hdev = hdev;
	prog_entry->type = prog_type;

	/* finally store the index in the device list */
	err = hid_bpf_populate_hdev(hdev, prog_type);
	if (err) {
		hid_bpf_release_prog_at(prog_table_idx);
		goto err_unlock;
	}

	link->hid_table_index = prog_table_idx;

	err = bpf_link_prime(&link->link, &link_primer);
	if (err)
		goto err_unlock;

	mutex_unlock(&hid_bpf_attach_lock);

	return bpf_link_settle(&link_primer);

 err_unlock:
	mutex_unlock(&hid_bpf_attach_lock);

	bpf_prog_put(prog);
	kfree(link);

	return err;
}
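
/*
 * called when the HID device goes away: disable all of its attached programs
 * and schedule the release work to drop them from the jump table
 */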
void __hid_bpf_destroy_device(struct hid_device *hdev)
{
	int type, i;
	struct hid_bpf_prog_list *prog_list;

	rcu_read_lock();

	for (type = 0; type < HID_BPF_PROG_TYPE_MAX; type++) {
		prog_list = rcu_dereference(hdev->bpf.progs[type]);

		if (!prog_list)
			continue;

		for (i = 0; i < prog_list->prog_cnt; i++)
			__clear_bit(prog_list->prog_idx[i], jmp_table.enabled);
	}

	rcu_read_unlock();

	for (type = 0; type < HID_BPF_PROG_TYPE_MAX; type++)
		__hid_bpf_set_hdev_progs(hdev, NULL, type);

	/* schedule release of all detached progs */
	schedule_work(&release_work);
}
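
/* number of pre-loaded entrypoint programs (currently only hid_tail_call) */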
#define HID_BPF_PROGS_COUNT 1

static struct bpf_link *links[HID_BPF_PROGS_COUNT];
static struct entrypoints_bpf *skel;
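
/* release the jump table map reference, the entrypoint links and the skeleton */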
void hid_bpf_free_links_and_skel(void)
{
	int i;

	/* the following is enough to release all programs attached to hid */
	if (jmp_table.map)
		bpf_map_put_with_uref(jmp_table.map);

	for (i = 0; i < ARRAY_SIZE(links); i++) {
		if (!IS_ERR_OR_NULL(links[i]))
			bpf_link_put(links[i]);
	}
	entrypoints_bpf__destroy(skel);
}
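
/* attach one pre-loaded entrypoint program and keep a reference on its bpf_link */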
#define ATTACH_AND_STORE_LINK(__name) do {					\
	err = entrypoints_bpf__##__name##__attach(skel);			\
	if (err)								\
		goto out;							\
										\
	links[idx] = bpf_link_get_from_fd(skel->links.__name##_fd);		\
	if (IS_ERR(links[idx])) {						\
		err = PTR_ERR(links[idx]);					\
		goto out;							\
	}									\
										\
	/* Avoid taking over stdin/stdout/stderr of init process. Zeroing out	\
	 * makes skel_closenz() a no-op later in iterators_bpf__destroy().	\
	 */									\
	close_fd(skel->links.__name##_fd);					\
	skel->links.__name##_fd = 0;						\
	idx++;									\
} while (0)
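
/*
 * load the pre-compiled entrypoints skeleton, take a reference on the
 * hid_jmp_table prog_array and attach the hid_tail_call entrypoint
 */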
int hid_bpf_preload_skel(void)
{
	int err, idx = 0;

	skel = entrypoints_bpf__open();
	if (!skel)
		return -ENOMEM;

	err = entrypoints_bpf__load(skel);
	if (err)
		goto out;

	jmp_table.map = bpf_map_get_with_uref(skel->maps.hid_jmp_table.map_fd);
	if (IS_ERR(jmp_table.map)) {
		err = PTR_ERR(jmp_table.map);
		goto out;
	}

	ATTACH_AND_STORE_LINK(hid_tail_call);

	return 0;
out:
	hid_bpf_free_links_and_skel();
	return err;
}