diff --git a/include/linux/btf.h b/include/linux/btf.h index d38aa4251c28..9ed00077db6e 100644 --- a/include/linux/btf.h +++ b/include/linux/btf.h @@ -487,6 +487,7 @@ const struct btf_member * btf_get_prog_ctx_type(struct bpf_verifier_log *log, const struct btf *btf, const struct btf_type *t, enum bpf_prog_type prog_type, int arg); +int get_kern_ctx_btf_id(struct bpf_verifier_log *log, enum bpf_prog_type prog_type); bool btf_types_are_same(const struct btf *btf1, u32 id1, const struct btf *btf2, u32 id2); #else @@ -531,6 +532,10 @@ btf_get_prog_ctx_type(struct bpf_verifier_log *log, const struct btf *btf, { return NULL; } +static inline int get_kern_ctx_btf_id(struct bpf_verifier_log *log, + enum bpf_prog_type prog_type) { + return -EINVAL; +} static inline bool btf_types_are_same(const struct btf *btf1, u32 id1, const struct btf *btf2, u32 id2) { diff --git a/kernel/bpf/btf.c b/kernel/bpf/btf.c index 1c78d4df9e18..1a59cc7ad730 100644 --- a/kernel/bpf/btf.c +++ b/kernel/bpf/btf.c @@ -5603,6 +5603,26 @@ static int btf_translate_to_vmlinux(struct bpf_verifier_log *log, return kern_ctx_type->type; } +int get_kern_ctx_btf_id(struct bpf_verifier_log *log, enum bpf_prog_type prog_type) +{ + const struct btf_member *kctx_member; + const struct btf_type *conv_struct; + const struct btf_type *kctx_type; + u32 kctx_type_id; + + conv_struct = bpf_ctx_convert.t; + /* get member for kernel ctx type */ + kctx_member = btf_type_member(conv_struct) + bpf_ctx_convert_map[prog_type] * 2 + 1; + kctx_type_id = kctx_member->type; + kctx_type = btf_type_by_id(btf_vmlinux, kctx_type_id); + if (!btf_type_is_struct(kctx_type)) { + bpf_log(log, "kern ctx type id %u is not a struct\n", kctx_type_id); + return -EINVAL; + } + + return kctx_type_id; +} + BTF_ID_LIST(bpf_ctx_convert_btf_id) BTF_ID(struct, bpf_ctx_convert) diff --git a/kernel/bpf/helpers.c b/kernel/bpf/helpers.c index 8e0d80b6bf09..e45e72676f5a 100644 --- a/kernel/bpf/helpers.c +++ b/kernel/bpf/helpers.c @@ -1879,6 +1879,11 @@ void bpf_task_release(struct task_struct *p) put_task_struct_rcu_user(p); } +void *bpf_cast_to_kern_ctx(void *obj) +{ + return obj; +} + __diag_pop(); BTF_SET8_START(generic_btf_ids) @@ -1907,6 +1912,7 @@ BTF_ID(struct, task_struct) BTF_ID(func, bpf_task_release) BTF_SET8_START(common_btf_ids) +BTF_ID_FLAGS(func, bpf_cast_to_kern_ctx) BTF_SET8_END(common_btf_ids) static const struct btf_kfunc_id_set common_kfunc_set = { diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c index eb090af35477..d63cc0fb0a9c 100644 --- a/kernel/bpf/verifier.c +++ b/kernel/bpf/verifier.c @@ -7907,6 +7907,7 @@ struct bpf_kfunc_call_arg_meta { u32 ref_obj_id; u8 release_regno; bool r0_rdonly; + u32 ret_btf_id; u64 r0_size; struct { u64 value; @@ -8151,6 +8152,7 @@ enum special_kfunc_type { KF_bpf_list_push_back, KF_bpf_list_pop_front, KF_bpf_list_pop_back, + KF_bpf_cast_to_kern_ctx, }; BTF_SET_START(special_kfunc_set) @@ -8160,6 +8162,7 @@ BTF_ID(func, bpf_list_push_front) BTF_ID(func, bpf_list_push_back) BTF_ID(func, bpf_list_pop_front) BTF_ID(func, bpf_list_pop_back) +BTF_ID(func, bpf_cast_to_kern_ctx) BTF_SET_END(special_kfunc_set) BTF_ID_LIST(special_kfunc_list) @@ -8169,6 +8172,7 @@ BTF_ID(func, bpf_list_push_front) BTF_ID(func, bpf_list_push_back) BTF_ID(func, bpf_list_pop_front) BTF_ID(func, bpf_list_pop_back) +BTF_ID(func, bpf_cast_to_kern_ctx) static enum kfunc_ptr_arg_type get_kfunc_ptr_arg_type(struct bpf_verifier_env *env, @@ -8182,6 +8186,9 @@ get_kfunc_ptr_arg_type(struct bpf_verifier_env *env, struct bpf_reg_state *reg = ®s[regno]; bool arg_mem_size = false; + if (meta->func_id == special_kfunc_list[KF_bpf_cast_to_kern_ctx]) + return KF_ARG_PTR_TO_CTX; + /* In this function, we verify the kfunc's BTF as per the argument type, * leaving the rest of the verification with respect to the register * type to our caller. When a set of conditions hold in the BTF type of @@ -8668,6 +8675,13 @@ static int check_kfunc_args(struct bpf_verifier_env *env, struct bpf_kfunc_call_ verbose(env, "arg#%d expected pointer to ctx, but got %s\n", i, btf_type_str(t)); return -EINVAL; } + + if (meta->func_id == special_kfunc_list[KF_bpf_cast_to_kern_ctx]) { + ret = get_kern_ctx_btf_id(&env->log, resolve_prog_type(env->prog)); + if (ret < 0) + return -EINVAL; + meta->ret_btf_id = ret; + } break; case KF_ARG_PTR_TO_ALLOC_BTF_ID: if (reg->type != (PTR_TO_BTF_ID | MEM_ALLOC)) { @@ -8922,6 +8936,11 @@ static int check_kfunc_call(struct bpf_verifier_env *env, struct bpf_insn *insn, regs[BPF_REG_0].btf = field->list_head.btf; regs[BPF_REG_0].btf_id = field->list_head.value_btf_id; regs[BPF_REG_0].off = field->list_head.node_offset; + } else if (meta.func_id == special_kfunc_list[KF_bpf_cast_to_kern_ctx]) { + mark_reg_known_zero(env, regs, BPF_REG_0); + regs[BPF_REG_0].type = PTR_TO_BTF_ID | PTR_TRUSTED; + regs[BPF_REG_0].btf = desc_btf; + regs[BPF_REG_0].btf_id = meta.ret_btf_id; } else { verbose(env, "kernel function %s unhandled dynamic return type\n", meta.func_name); @@ -15175,6 +15194,9 @@ static int fixup_kfunc_call(struct bpf_verifier_env *env, struct bpf_insn *insn, insn_buf[1] = addr[1]; insn_buf[2] = *insn; *cnt = 3; + } else if (desc->func_id == special_kfunc_list[KF_bpf_cast_to_kern_ctx]) { + insn_buf[0] = BPF_MOV64_REG(BPF_REG_0, BPF_REG_1); + *cnt = 1; } return 0; }