diff --git a/include/linux/tnum.h b/include/linux/tnum.h index 1c3948a1d6ad..3c13240077b8 100644 --- a/include/linux/tnum.h +++ b/include/linux/tnum.h @@ -106,6 +106,10 @@ int tnum_sbin(char *str, size_t size, struct tnum a); struct tnum tnum_subreg(struct tnum a); /* Returns the tnum with the lower 32-bit subreg cleared */ struct tnum tnum_clear_subreg(struct tnum a); +/* Returns the tnum with the lower 32-bit subreg in *reg* set to the lower + * 32-bit subreg in *subreg* + */ +struct tnum tnum_with_subreg(struct tnum reg, struct tnum subreg); /* Returns the tnum with the lower 32-bit subreg set to value */ struct tnum tnum_const_subreg(struct tnum a, u32 value); /* Returns true if 32-bit subreg @a is a known constant*/ diff --git a/kernel/bpf/tnum.c b/kernel/bpf/tnum.c index 3d7127f439a1..f4c91c9b27d7 100644 --- a/kernel/bpf/tnum.c +++ b/kernel/bpf/tnum.c @@ -208,7 +208,12 @@ struct tnum tnum_clear_subreg(struct tnum a) return tnum_lshift(tnum_rshift(a, 32), 32); } +struct tnum tnum_with_subreg(struct tnum reg, struct tnum subreg) +{ + return tnum_or(tnum_clear_subreg(reg), tnum_subreg(subreg)); +} + struct tnum tnum_const_subreg(struct tnum a, u32 value) { - return tnum_or(tnum_clear_subreg(a), tnum_const(value)); + return tnum_with_subreg(a, tnum_const(value)); } diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c index 9ae6eae13471..39ce141c55d3 100644 --- a/kernel/bpf/verifier.c +++ b/kernel/bpf/verifier.c @@ -14453,6 +14453,158 @@ static int is_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg return is_scalar_branch_taken(reg1, reg2, opcode, is_jmp32); } +/* Opcode that corresponds to a *false* branch condition. + * E.g., if r1 < r2, then reverse (false) condition is r1 >= r2 + */ +static u8 rev_opcode(u8 opcode) +{ + switch (opcode) { + case BPF_JEQ: return BPF_JNE; + case BPF_JNE: return BPF_JEQ; + /* JSET doesn't have it's reverse opcode in BPF, so add + * BPF_X flag to denote the reverse of that operation + */ + case BPF_JSET: return BPF_JSET | BPF_X; + case BPF_JSET | BPF_X: return BPF_JSET; + case BPF_JGE: return BPF_JLT; + case BPF_JGT: return BPF_JLE; + case BPF_JLE: return BPF_JGT; + case BPF_JLT: return BPF_JGE; + case BPF_JSGE: return BPF_JSLT; + case BPF_JSGT: return BPF_JSLE; + case BPF_JSLE: return BPF_JSGT; + case BPF_JSLT: return BPF_JSGE; + default: return 0; + } +} + +/* Refine range knowledge for 2 conditional operation. */ +static void regs_refine_cond_op(struct bpf_reg_state *reg1, struct bpf_reg_state *reg2, + u8 opcode, bool is_jmp32) +{ + struct tnum t; + u64 val; + +again: + switch (opcode) { + case BPF_JEQ: + if (is_jmp32) { + reg1->u32_min_value = max(reg1->u32_min_value, reg2->u32_min_value); + reg1->u32_max_value = min(reg1->u32_max_value, reg2->u32_max_value); + reg1->s32_min_value = max(reg1->s32_min_value, reg2->s32_min_value); + reg1->s32_max_value = min(reg1->s32_max_value, reg2->s32_max_value); + reg2->u32_min_value = reg1->u32_min_value; + reg2->u32_max_value = reg1->u32_max_value; + reg2->s32_min_value = reg1->s32_min_value; + reg2->s32_max_value = reg1->s32_max_value; + + t = tnum_intersect(tnum_subreg(reg1->var_off), tnum_subreg(reg2->var_off)); + reg1->var_off = tnum_with_subreg(reg1->var_off, t); + reg2->var_off = tnum_with_subreg(reg2->var_off, t); + } else { + reg1->umin_value = max(reg1->umin_value, reg2->umin_value); + reg1->umax_value = min(reg1->umax_value, reg2->umax_value); + reg1->smin_value = max(reg1->smin_value, reg2->smin_value); + reg1->smax_value = min(reg1->smax_value, reg2->smax_value); + reg2->umin_value = reg1->umin_value; + reg2->umax_value = reg1->umax_value; + reg2->smin_value = reg1->smin_value; + reg2->smax_value = reg1->smax_value; + + reg1->var_off = tnum_intersect(reg1->var_off, reg2->var_off); + reg2->var_off = reg1->var_off; + } + break; + case BPF_JNE: + /* we don't derive any new information for inequality yet */ + break; + case BPF_JSET: + if (!is_reg_const(reg2, is_jmp32)) + swap(reg1, reg2); + if (!is_reg_const(reg2, is_jmp32)) + break; + val = reg_const_value(reg2, is_jmp32); + /* BPF_JSET (i.e., TRUE branch, *not* BPF_JSET | BPF_X) + * requires single bit to learn something useful. E.g., if we + * know that `r1 & 0x3` is true, then which bits (0, 1, or both) + * are actually set? We can learn something definite only if + * it's a single-bit value to begin with. + * + * BPF_JSET | BPF_X (i.e., negation of BPF_JSET) doesn't have + * this restriction. I.e., !(r1 & 0x3) means neither bit 0 nor + * bit 1 is set, which we can readily use in adjustments. + */ + if (!is_power_of_2(val)) + break; + if (is_jmp32) { + t = tnum_or(tnum_subreg(reg1->var_off), tnum_const(val)); + reg1->var_off = tnum_with_subreg(reg1->var_off, t); + } else { + reg1->var_off = tnum_or(reg1->var_off, tnum_const(val)); + } + break; + case BPF_JSET | BPF_X: /* reverse of BPF_JSET, see rev_opcode() */ + if (!is_reg_const(reg2, is_jmp32)) + swap(reg1, reg2); + if (!is_reg_const(reg2, is_jmp32)) + break; + val = reg_const_value(reg2, is_jmp32); + if (is_jmp32) { + t = tnum_and(tnum_subreg(reg1->var_off), tnum_const(~val)); + reg1->var_off = tnum_with_subreg(reg1->var_off, t); + } else { + reg1->var_off = tnum_and(reg1->var_off, tnum_const(~val)); + } + break; + case BPF_JLE: + if (is_jmp32) { + reg1->u32_max_value = min(reg1->u32_max_value, reg2->u32_max_value); + reg2->u32_min_value = max(reg1->u32_min_value, reg2->u32_min_value); + } else { + reg1->umax_value = min(reg1->umax_value, reg2->umax_value); + reg2->umin_value = max(reg1->umin_value, reg2->umin_value); + } + break; + case BPF_JLT: + if (is_jmp32) { + reg1->u32_max_value = min(reg1->u32_max_value, reg2->u32_max_value - 1); + reg2->u32_min_value = max(reg1->u32_min_value + 1, reg2->u32_min_value); + } else { + reg1->umax_value = min(reg1->umax_value, reg2->umax_value - 1); + reg2->umin_value = max(reg1->umin_value + 1, reg2->umin_value); + } + break; + case BPF_JSLE: + if (is_jmp32) { + reg1->s32_max_value = min(reg1->s32_max_value, reg2->s32_max_value); + reg2->s32_min_value = max(reg1->s32_min_value, reg2->s32_min_value); + } else { + reg1->smax_value = min(reg1->smax_value, reg2->smax_value); + reg2->smin_value = max(reg1->smin_value, reg2->smin_value); + } + break; + case BPF_JSLT: + if (is_jmp32) { + reg1->s32_max_value = min(reg1->s32_max_value, reg2->s32_max_value - 1); + reg2->s32_min_value = max(reg1->s32_min_value + 1, reg2->s32_min_value); + } else { + reg1->smax_value = min(reg1->smax_value, reg2->smax_value - 1); + reg2->smin_value = max(reg1->smin_value + 1, reg2->smin_value); + } + break; + case BPF_JGE: + case BPF_JGT: + case BPF_JSGE: + case BPF_JSGT: + /* just reuse LE/LT logic above */ + opcode = flip_opcode(opcode); + swap(reg1, reg2); + goto again; + default: + return; + } +} + /* Adjusts the register min/max values in the case that the dst_reg and * src_reg are both SCALAR_VALUE registers (or we are simply doing a BPF_K * check, in which case we havea fake SCALAR_VALUE representing insn->imm). @@ -14465,13 +14617,6 @@ static void reg_set_min_max(struct bpf_reg_state *true_reg1, struct bpf_reg_state *false_reg2, u8 opcode, bool is_jmp32) { - struct tnum false_32off, false_64off; - struct tnum true_32off, true_64off; - u64 uval; - u32 uval32; - s64 sval; - s32 sval32; - /* If either register is a pointer, we can't learn anything about its * variable offset from the compare (unless they were a pointer into * the same object, but we don't bother with that). @@ -14479,192 +14624,15 @@ static void reg_set_min_max(struct bpf_reg_state *true_reg1, if (false_reg1->type != SCALAR_VALUE || false_reg2->type != SCALAR_VALUE) return; - /* we expect right-hand registers (src ones) to be constants, for now */ - if (!is_reg_const(false_reg2, is_jmp32)) { - opcode = flip_opcode(opcode); - swap(true_reg1, true_reg2); - swap(false_reg1, false_reg2); - } - if (!is_reg_const(false_reg2, is_jmp32)) - return; + /* fallthrough (FALSE) branch */ + regs_refine_cond_op(false_reg1, false_reg2, rev_opcode(opcode), is_jmp32); + reg_bounds_sync(false_reg1); + reg_bounds_sync(false_reg2); - false_32off = tnum_subreg(false_reg1->var_off); - false_64off = false_reg1->var_off; - true_32off = tnum_subreg(true_reg1->var_off); - true_64off = true_reg1->var_off; - uval = false_reg2->var_off.value; - uval32 = (u32)tnum_subreg(false_reg2->var_off).value; - sval = (s64)uval; - sval32 = (s32)uval32; - - switch (opcode) { - /* JEQ/JNE comparison doesn't change the register equivalence. - * - * r1 = r2; - * if (r1 == 42) goto label; - * ... - * label: // here both r1 and r2 are known to be 42. - * - * Hence when marking register as known preserve it's ID. - */ - case BPF_JEQ: - if (is_jmp32) { - __mark_reg32_known(true_reg1, uval32); - true_32off = tnum_subreg(true_reg1->var_off); - } else { - ___mark_reg_known(true_reg1, uval); - true_64off = true_reg1->var_off; - } - break; - case BPF_JNE: - if (is_jmp32) { - __mark_reg32_known(false_reg1, uval32); - false_32off = tnum_subreg(false_reg1->var_off); - } else { - ___mark_reg_known(false_reg1, uval); - false_64off = false_reg1->var_off; - } - break; - case BPF_JSET: - if (is_jmp32) { - false_32off = tnum_and(false_32off, tnum_const(~uval32)); - if (is_power_of_2(uval32)) - true_32off = tnum_or(true_32off, - tnum_const(uval32)); - } else { - false_64off = tnum_and(false_64off, tnum_const(~uval)); - if (is_power_of_2(uval)) - true_64off = tnum_or(true_64off, - tnum_const(uval)); - } - break; - case BPF_JGE: - case BPF_JGT: - { - if (is_jmp32) { - u32 false_umax = opcode == BPF_JGT ? uval32 : uval32 - 1; - u32 true_umin = opcode == BPF_JGT ? uval32 + 1 : uval32; - - false_reg1->u32_max_value = min(false_reg1->u32_max_value, - false_umax); - true_reg1->u32_min_value = max(true_reg1->u32_min_value, - true_umin); - } else { - u64 false_umax = opcode == BPF_JGT ? uval : uval - 1; - u64 true_umin = opcode == BPF_JGT ? uval + 1 : uval; - - false_reg1->umax_value = min(false_reg1->umax_value, false_umax); - true_reg1->umin_value = max(true_reg1->umin_value, true_umin); - } - break; - } - case BPF_JSGE: - case BPF_JSGT: - { - if (is_jmp32) { - s32 false_smax = opcode == BPF_JSGT ? sval32 : sval32 - 1; - s32 true_smin = opcode == BPF_JSGT ? sval32 + 1 : sval32; - - false_reg1->s32_max_value = min(false_reg1->s32_max_value, false_smax); - true_reg1->s32_min_value = max(true_reg1->s32_min_value, true_smin); - } else { - s64 false_smax = opcode == BPF_JSGT ? sval : sval - 1; - s64 true_smin = opcode == BPF_JSGT ? sval + 1 : sval; - - false_reg1->smax_value = min(false_reg1->smax_value, false_smax); - true_reg1->smin_value = max(true_reg1->smin_value, true_smin); - } - break; - } - case BPF_JLE: - case BPF_JLT: - { - if (is_jmp32) { - u32 false_umin = opcode == BPF_JLT ? uval32 : uval32 + 1; - u32 true_umax = opcode == BPF_JLT ? uval32 - 1 : uval32; - - false_reg1->u32_min_value = max(false_reg1->u32_min_value, - false_umin); - true_reg1->u32_max_value = min(true_reg1->u32_max_value, - true_umax); - } else { - u64 false_umin = opcode == BPF_JLT ? uval : uval + 1; - u64 true_umax = opcode == BPF_JLT ? uval - 1 : uval; - - false_reg1->umin_value = max(false_reg1->umin_value, false_umin); - true_reg1->umax_value = min(true_reg1->umax_value, true_umax); - } - break; - } - case BPF_JSLE: - case BPF_JSLT: - { - if (is_jmp32) { - s32 false_smin = opcode == BPF_JSLT ? sval32 : sval32 + 1; - s32 true_smax = opcode == BPF_JSLT ? sval32 - 1 : sval32; - - false_reg1->s32_min_value = max(false_reg1->s32_min_value, false_smin); - true_reg1->s32_max_value = min(true_reg1->s32_max_value, true_smax); - } else { - s64 false_smin = opcode == BPF_JSLT ? sval : sval + 1; - s64 true_smax = opcode == BPF_JSLT ? sval - 1 : sval; - - false_reg1->smin_value = max(false_reg1->smin_value, false_smin); - true_reg1->smax_value = min(true_reg1->smax_value, true_smax); - } - break; - } - default: - return; - } - - if (is_jmp32) { - false_reg1->var_off = tnum_or(tnum_clear_subreg(false_64off), - tnum_subreg(false_32off)); - true_reg1->var_off = tnum_or(tnum_clear_subreg(true_64off), - tnum_subreg(true_32off)); - reg_bounds_sync(false_reg1); - reg_bounds_sync(true_reg1); - } else { - false_reg1->var_off = false_64off; - true_reg1->var_off = true_64off; - reg_bounds_sync(false_reg1); - reg_bounds_sync(true_reg1); - } -} - -/* Regs are known to be equal, so intersect their min/max/var_off */ -static void __reg_combine_min_max(struct bpf_reg_state *src_reg, - struct bpf_reg_state *dst_reg) -{ - src_reg->umin_value = dst_reg->umin_value = max(src_reg->umin_value, - dst_reg->umin_value); - src_reg->umax_value = dst_reg->umax_value = min(src_reg->umax_value, - dst_reg->umax_value); - src_reg->smin_value = dst_reg->smin_value = max(src_reg->smin_value, - dst_reg->smin_value); - src_reg->smax_value = dst_reg->smax_value = min(src_reg->smax_value, - dst_reg->smax_value); - src_reg->var_off = dst_reg->var_off = tnum_intersect(src_reg->var_off, - dst_reg->var_off); - reg_bounds_sync(src_reg); - reg_bounds_sync(dst_reg); -} - -static void reg_combine_min_max(struct bpf_reg_state *true_src, - struct bpf_reg_state *true_dst, - struct bpf_reg_state *false_src, - struct bpf_reg_state *false_dst, - u8 opcode) -{ - switch (opcode) { - case BPF_JEQ: - __reg_combine_min_max(true_src, true_dst); - break; - case BPF_JNE: - __reg_combine_min_max(false_src, false_dst); - break; - } + /* jump (TRUE) branch */ + regs_refine_cond_op(true_reg1, true_reg2, opcode, is_jmp32); + reg_bounds_sync(true_reg1); + reg_bounds_sync(true_reg2); } static void mark_ptr_or_null_reg(struct bpf_func_state *state, @@ -14961,22 +14929,12 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env, reg_set_min_max(&other_branch_regs[insn->dst_reg], &other_branch_regs[insn->src_reg], dst_reg, src_reg, opcode, is_jmp32); - - if (dst_reg->type == SCALAR_VALUE && - src_reg->type == SCALAR_VALUE && - !is_jmp32 && (opcode == BPF_JEQ || opcode == BPF_JNE)) { - /* Comparing for equality, we can combine knowledge */ - reg_combine_min_max(&other_branch_regs[insn->src_reg], - &other_branch_regs[insn->dst_reg], - src_reg, dst_reg, opcode); - } } else /* BPF_SRC(insn->code) == BPF_K */ { reg_set_min_max(&other_branch_regs[insn->dst_reg], src_reg /* fake one */, dst_reg, src_reg /* same fake one */, opcode, is_jmp32); } - if (BPF_SRC(insn->code) == BPF_X && src_reg->type == SCALAR_VALUE && src_reg->id && !WARN_ON_ONCE(src_reg->id != other_branch_regs[insn->src_reg].id)) {