diff --git a/drivers/net/ethernet/netronome/nfp/bpf/jit.c b/drivers/net/ethernet/netronome/nfp/bpf/jit.c index 33111739b210..1d9e36835404 100644 --- a/drivers/net/ethernet/netronome/nfp/bpf/jit.c +++ b/drivers/net/ethernet/netronome/nfp/bpf/jit.c @@ -34,10 +34,11 @@ #define pr_fmt(fmt) "NFP net bpf: " fmt #include -#include #include #include +#include #include +#include #include #include "main.h" @@ -415,6 +416,60 @@ emit_alu(struct nfp_prog *nfp_prog, swreg dst, reg.dst_lmextn, reg.src_lmextn); } +static void +__emit_mul(struct nfp_prog *nfp_prog, enum alu_dst_ab dst_ab, u16 areg, + enum mul_type type, enum mul_step step, u16 breg, bool swap, + bool wr_both, bool dst_lmextn, bool src_lmextn) +{ + u64 insn; + + insn = OP_MUL_BASE | + FIELD_PREP(OP_MUL_A_SRC, areg) | + FIELD_PREP(OP_MUL_B_SRC, breg) | + FIELD_PREP(OP_MUL_STEP, step) | + FIELD_PREP(OP_MUL_DST_AB, dst_ab) | + FIELD_PREP(OP_MUL_SW, swap) | + FIELD_PREP(OP_MUL_TYPE, type) | + FIELD_PREP(OP_MUL_WR_AB, wr_both) | + FIELD_PREP(OP_MUL_SRC_LMEXTN, src_lmextn) | + FIELD_PREP(OP_MUL_DST_LMEXTN, dst_lmextn); + + nfp_prog_push(nfp_prog, insn); +} + +static void +emit_mul(struct nfp_prog *nfp_prog, swreg lreg, enum mul_type type, + enum mul_step step, swreg rreg) +{ + struct nfp_insn_ur_regs reg; + u16 areg; + int err; + + if (type == MUL_TYPE_START && step != MUL_STEP_NONE) { + nfp_prog->error = -EINVAL; + return; + } + + if (step == MUL_LAST || step == MUL_LAST_2) { + /* When type is step and step Number is LAST or LAST2, left + * source is used as destination. + */ + err = swreg_to_unrestricted(lreg, reg_none(), rreg, ®); + areg = reg.dst; + } else { + err = swreg_to_unrestricted(reg_none(), lreg, rreg, ®); + areg = reg.areg; + } + + if (err) { + nfp_prog->error = err; + return; + } + + __emit_mul(nfp_prog, reg.dst_ab, areg, type, step, reg.breg, reg.swap, + reg.wr_both, reg.dst_lmextn, reg.src_lmextn); +} + static void __emit_ld_field(struct nfp_prog *nfp_prog, enum shf_sc sc, u8 areg, u8 bmask, u8 breg, u8 shift, bool imm8, @@ -1380,6 +1435,133 @@ static void wrp_end32(struct nfp_prog *nfp_prog, swreg reg_in, u8 gpr_out) SHF_SC_R_ROT, 16); } +static void +wrp_mul_u32(struct nfp_prog *nfp_prog, swreg dst_hi, swreg dst_lo, swreg lreg, + swreg rreg, bool gen_high_half) +{ + emit_mul(nfp_prog, lreg, MUL_TYPE_START, MUL_STEP_NONE, rreg); + emit_mul(nfp_prog, lreg, MUL_TYPE_STEP_32x32, MUL_STEP_1, rreg); + emit_mul(nfp_prog, lreg, MUL_TYPE_STEP_32x32, MUL_STEP_2, rreg); + emit_mul(nfp_prog, lreg, MUL_TYPE_STEP_32x32, MUL_STEP_3, rreg); + emit_mul(nfp_prog, lreg, MUL_TYPE_STEP_32x32, MUL_STEP_4, rreg); + emit_mul(nfp_prog, dst_lo, MUL_TYPE_STEP_32x32, MUL_LAST, reg_none()); + if (gen_high_half) + emit_mul(nfp_prog, dst_hi, MUL_TYPE_STEP_32x32, MUL_LAST_2, + reg_none()); + else + wrp_immed(nfp_prog, dst_hi, 0); +} + +static void +wrp_mul_u16(struct nfp_prog *nfp_prog, swreg dst_hi, swreg dst_lo, swreg lreg, + swreg rreg) +{ + emit_mul(nfp_prog, lreg, MUL_TYPE_START, MUL_STEP_NONE, rreg); + emit_mul(nfp_prog, lreg, MUL_TYPE_STEP_16x16, MUL_STEP_1, rreg); + emit_mul(nfp_prog, lreg, MUL_TYPE_STEP_16x16, MUL_STEP_2, rreg); + emit_mul(nfp_prog, dst_lo, MUL_TYPE_STEP_16x16, MUL_LAST, reg_none()); +} + +static int +wrp_mul(struct nfp_prog *nfp_prog, struct nfp_insn_meta *meta, + bool gen_high_half, bool ropnd_from_reg) +{ + swreg multiplier, multiplicand, dst_hi, dst_lo; + const struct bpf_insn *insn = &meta->insn; + u32 lopnd_max, ropnd_max; + u8 dst_reg; + + dst_reg = insn->dst_reg; + multiplicand = reg_a(dst_reg * 2); + dst_hi = reg_both(dst_reg * 2 + 1); + dst_lo = reg_both(dst_reg * 2); + lopnd_max = meta->umax_dst; + if (ropnd_from_reg) { + multiplier = reg_b(insn->src_reg * 2); + ropnd_max = meta->umax_src; + } else { + u32 imm = insn->imm; + + multiplier = ur_load_imm_any(nfp_prog, imm, imm_b(nfp_prog)); + ropnd_max = imm; + } + if (lopnd_max > U16_MAX || ropnd_max > U16_MAX) + wrp_mul_u32(nfp_prog, dst_hi, dst_lo, multiplicand, multiplier, + gen_high_half); + else + wrp_mul_u16(nfp_prog, dst_hi, dst_lo, multiplicand, multiplier); + + return 0; +} + +static int wrp_div_imm(struct nfp_prog *nfp_prog, u8 dst, u64 imm) +{ + swreg dst_both = reg_both(dst), dst_a = reg_a(dst), dst_b = reg_a(dst); + struct reciprocal_value_adv rvalue; + u8 pre_shift, exp; + swreg magic; + + if (imm > U32_MAX) { + wrp_immed(nfp_prog, dst_both, 0); + return 0; + } + + /* NOTE: because we are using "reciprocal_value_adv" which doesn't + * support "divisor > (1u << 31)", we need to JIT separate NFP sequence + * to handle such case which actually equals to the result of unsigned + * comparison "dst >= imm" which could be calculated using the following + * NFP sequence: + * + * alu[--, dst, -, imm] + * immed[imm, 0] + * alu[dst, imm, +carry, 0] + * + */ + if (imm > 1U << 31) { + swreg tmp_b = ur_load_imm_any(nfp_prog, imm, imm_b(nfp_prog)); + + emit_alu(nfp_prog, reg_none(), dst_a, ALU_OP_SUB, tmp_b); + wrp_immed(nfp_prog, imm_a(nfp_prog), 0); + emit_alu(nfp_prog, dst_both, imm_a(nfp_prog), ALU_OP_ADD_C, + reg_imm(0)); + return 0; + } + + rvalue = reciprocal_value_adv(imm, 32); + exp = rvalue.exp; + if (rvalue.is_wide_m && !(imm & 1)) { + pre_shift = fls(imm & -imm) - 1; + rvalue = reciprocal_value_adv(imm >> pre_shift, 32 - pre_shift); + } else { + pre_shift = 0; + } + magic = ur_load_imm_any(nfp_prog, rvalue.m, imm_b(nfp_prog)); + if (imm == 1U << exp) { + emit_shf(nfp_prog, dst_both, reg_none(), SHF_OP_NONE, dst_b, + SHF_SC_R_SHF, exp); + } else if (rvalue.is_wide_m) { + wrp_mul_u32(nfp_prog, imm_both(nfp_prog), reg_none(), dst_a, + magic, true); + emit_alu(nfp_prog, dst_both, dst_a, ALU_OP_SUB, + imm_b(nfp_prog)); + emit_shf(nfp_prog, dst_both, reg_none(), SHF_OP_NONE, dst_b, + SHF_SC_R_SHF, 1); + emit_alu(nfp_prog, dst_both, dst_a, ALU_OP_ADD, + imm_b(nfp_prog)); + emit_shf(nfp_prog, dst_both, reg_none(), SHF_OP_NONE, dst_b, + SHF_SC_R_SHF, rvalue.sh - 1); + } else { + if (pre_shift) + emit_shf(nfp_prog, dst_both, reg_none(), SHF_OP_NONE, + dst_b, SHF_SC_R_SHF, pre_shift); + wrp_mul_u32(nfp_prog, dst_both, reg_none(), dst_a, magic, true); + emit_shf(nfp_prog, dst_both, reg_none(), SHF_OP_NONE, + dst_b, SHF_SC_R_SHF, rvalue.sh); + } + + return 0; +} + static int adjust_head(struct nfp_prog *nfp_prog, struct nfp_insn_meta *meta) { swreg tmp = imm_a(nfp_prog), tmp_len = imm_b(nfp_prog); @@ -1684,6 +1866,31 @@ static int sub_imm64(struct nfp_prog *nfp_prog, struct nfp_insn_meta *meta) return 0; } +static int mul_reg64(struct nfp_prog *nfp_prog, struct nfp_insn_meta *meta) +{ + return wrp_mul(nfp_prog, meta, true, true); +} + +static int mul_imm64(struct nfp_prog *nfp_prog, struct nfp_insn_meta *meta) +{ + return wrp_mul(nfp_prog, meta, true, false); +} + +static int div_imm64(struct nfp_prog *nfp_prog, struct nfp_insn_meta *meta) +{ + const struct bpf_insn *insn = &meta->insn; + + return wrp_div_imm(nfp_prog, insn->dst_reg * 2, insn->imm); +} + +static int div_reg64(struct nfp_prog *nfp_prog, struct nfp_insn_meta *meta) +{ + /* NOTE: verifier hook has rejected cases for which verifier doesn't + * know whether the source operand is constant or not. + */ + return wrp_div_imm(nfp_prog, meta->insn.dst_reg * 2, meta->umin_src); +} + static int neg_reg64(struct nfp_prog *nfp_prog, struct nfp_insn_meta *meta) { const struct bpf_insn *insn = &meta->insn; @@ -1772,8 +1979,8 @@ static int shl_reg64(struct nfp_prog *nfp_prog, struct nfp_insn_meta *meta) u8 dst, src; dst = insn->dst_reg * 2; - umin = meta->umin; - umax = meta->umax; + umin = meta->umin_src; + umax = meta->umax_src; if (umin == umax) return __shl_imm64(nfp_prog, dst, umin); @@ -1881,8 +2088,8 @@ static int shr_reg64(struct nfp_prog *nfp_prog, struct nfp_insn_meta *meta) u8 dst, src; dst = insn->dst_reg * 2; - umin = meta->umin; - umax = meta->umax; + umin = meta->umin_src; + umax = meta->umax_src; if (umin == umax) return __shr_imm64(nfp_prog, dst, umin); @@ -1995,8 +2202,8 @@ static int ashr_reg64(struct nfp_prog *nfp_prog, struct nfp_insn_meta *meta) u8 dst, src; dst = insn->dst_reg * 2; - umin = meta->umin; - umax = meta->umax; + umin = meta->umin_src; + umax = meta->umax_src; if (umin == umax) return __ashr_imm64(nfp_prog, dst, umin); @@ -2097,6 +2304,26 @@ static int sub_imm(struct nfp_prog *nfp_prog, struct nfp_insn_meta *meta) return wrp_alu32_imm(nfp_prog, meta, ALU_OP_SUB, !meta->insn.imm); } +static int mul_reg(struct nfp_prog *nfp_prog, struct nfp_insn_meta *meta) +{ + return wrp_mul(nfp_prog, meta, false, true); +} + +static int mul_imm(struct nfp_prog *nfp_prog, struct nfp_insn_meta *meta) +{ + return wrp_mul(nfp_prog, meta, false, false); +} + +static int div_reg(struct nfp_prog *nfp_prog, struct nfp_insn_meta *meta) +{ + return div_reg64(nfp_prog, meta); +} + +static int div_imm(struct nfp_prog *nfp_prog, struct nfp_insn_meta *meta) +{ + return div_imm64(nfp_prog, meta); +} + static int neg_reg(struct nfp_prog *nfp_prog, struct nfp_insn_meta *meta) { u8 dst = meta->insn.dst_reg * 2; @@ -2848,6 +3075,10 @@ static const instr_cb_t instr_cb[256] = { [BPF_ALU64 | BPF_ADD | BPF_K] = add_imm64, [BPF_ALU64 | BPF_SUB | BPF_X] = sub_reg64, [BPF_ALU64 | BPF_SUB | BPF_K] = sub_imm64, + [BPF_ALU64 | BPF_MUL | BPF_X] = mul_reg64, + [BPF_ALU64 | BPF_MUL | BPF_K] = mul_imm64, + [BPF_ALU64 | BPF_DIV | BPF_X] = div_reg64, + [BPF_ALU64 | BPF_DIV | BPF_K] = div_imm64, [BPF_ALU64 | BPF_NEG] = neg_reg64, [BPF_ALU64 | BPF_LSH | BPF_X] = shl_reg64, [BPF_ALU64 | BPF_LSH | BPF_K] = shl_imm64, @@ -2867,6 +3098,10 @@ static const instr_cb_t instr_cb[256] = { [BPF_ALU | BPF_ADD | BPF_K] = add_imm, [BPF_ALU | BPF_SUB | BPF_X] = sub_reg, [BPF_ALU | BPF_SUB | BPF_K] = sub_imm, + [BPF_ALU | BPF_MUL | BPF_X] = mul_reg, + [BPF_ALU | BPF_MUL | BPF_K] = mul_imm, + [BPF_ALU | BPF_DIV | BPF_X] = div_reg, + [BPF_ALU | BPF_DIV | BPF_K] = div_imm, [BPF_ALU | BPF_NEG] = neg_reg, [BPF_ALU | BPF_LSH | BPF_K] = shl_imm, [BPF_ALU | BPF_END | BPF_X] = end_reg32, diff --git a/drivers/net/ethernet/netronome/nfp/bpf/main.h b/drivers/net/ethernet/netronome/nfp/bpf/main.h index 654fe7823e5e..9845c1a2d4c2 100644 --- a/drivers/net/ethernet/netronome/nfp/bpf/main.h +++ b/drivers/net/ethernet/netronome/nfp/bpf/main.h @@ -263,8 +263,10 @@ struct nfp_bpf_reg_state { * @func_id: function id for call instructions * @arg1: arg1 for call instructions * @arg2: arg2 for call instructions - * @umin: copy of core verifier umin_value. - * @umax: copy of core verifier umax_value. + * @umin_src: copy of core verifier umin_value for src opearnd. + * @umax_src: copy of core verifier umax_value for src operand. + * @umin_dst: copy of core verifier umin_value for dst opearnd. + * @umax_dst: copy of core verifier umax_value for dst operand. * @off: index of first generated machine instruction (in nfp_prog.prog) * @n: eBPF instruction number * @flags: eBPF instruction extra optimization flags @@ -300,12 +302,15 @@ struct nfp_insn_meta { struct bpf_reg_state arg1; struct nfp_bpf_reg_state arg2; }; - /* We are interested in range info for some operands, - * for example, the shift amount. + /* We are interested in range info for operands of ALU + * operations. For example, shift amount, multiplicand and + * multiplier etc. */ struct { - u64 umin; - u64 umax; + u64 umin_src; + u64 umax_src; + u64 umin_dst; + u64 umax_dst; }; }; unsigned int off; @@ -339,6 +344,11 @@ static inline u8 mbpf_mode(const struct nfp_insn_meta *meta) return BPF_MODE(meta->insn.code); } +static inline bool is_mbpf_alu(const struct nfp_insn_meta *meta) +{ + return mbpf_class(meta) == BPF_ALU64 || mbpf_class(meta) == BPF_ALU; +} + static inline bool is_mbpf_load(const struct nfp_insn_meta *meta) { return (meta->insn.code & ~BPF_SIZE_MASK) == (BPF_LDX | BPF_MEM); @@ -384,23 +394,14 @@ static inline bool is_mbpf_xadd(const struct nfp_insn_meta *meta) return (meta->insn.code & ~BPF_SIZE_MASK) == (BPF_STX | BPF_XADD); } -static inline bool is_mbpf_indir_shift(const struct nfp_insn_meta *meta) +static inline bool is_mbpf_mul(const struct nfp_insn_meta *meta) { - u8 code = meta->insn.code; - bool is_alu, is_shift; - u8 opclass, opcode; + return is_mbpf_alu(meta) && mbpf_op(meta) == BPF_MUL; +} - opclass = BPF_CLASS(code); - is_alu = opclass == BPF_ALU64 || opclass == BPF_ALU; - if (!is_alu) - return false; - - opcode = BPF_OP(code); - is_shift = opcode == BPF_LSH || opcode == BPF_RSH || opcode == BPF_ARSH; - if (!is_shift) - return false; - - return BPF_SRC(code) == BPF_X; +static inline bool is_mbpf_div(const struct nfp_insn_meta *meta) +{ + return is_mbpf_alu(meta) && mbpf_op(meta) == BPF_DIV; } /** diff --git a/drivers/net/ethernet/netronome/nfp/bpf/offload.c b/drivers/net/ethernet/netronome/nfp/bpf/offload.c index 7eae4c0266f8..78f44c4d95b4 100644 --- a/drivers/net/ethernet/netronome/nfp/bpf/offload.c +++ b/drivers/net/ethernet/netronome/nfp/bpf/offload.c @@ -190,8 +190,10 @@ nfp_prog_prepare(struct nfp_prog *nfp_prog, const struct bpf_insn *prog, meta->insn = prog[i]; meta->n = i; - if (is_mbpf_indir_shift(meta)) - meta->umin = U64_MAX; + if (is_mbpf_alu(meta)) { + meta->umin_src = U64_MAX; + meta->umin_dst = U64_MAX; + } list_add_tail(&meta->l, &nfp_prog->insns); } diff --git a/drivers/net/ethernet/netronome/nfp/bpf/verifier.c b/drivers/net/ethernet/netronome/nfp/bpf/verifier.c index 4bfeba7b21b2..49ba0d645d36 100644 --- a/drivers/net/ethernet/netronome/nfp/bpf/verifier.c +++ b/drivers/net/ethernet/netronome/nfp/bpf/verifier.c @@ -516,6 +516,82 @@ nfp_bpf_check_xadd(struct nfp_prog *nfp_prog, struct nfp_insn_meta *meta, return nfp_bpf_check_ptr(nfp_prog, meta, env, meta->insn.dst_reg); } +static int +nfp_bpf_check_alu(struct nfp_prog *nfp_prog, struct nfp_insn_meta *meta, + struct bpf_verifier_env *env) +{ + const struct bpf_reg_state *sreg = + cur_regs(env) + meta->insn.src_reg; + const struct bpf_reg_state *dreg = + cur_regs(env) + meta->insn.dst_reg; + + meta->umin_src = min(meta->umin_src, sreg->umin_value); + meta->umax_src = max(meta->umax_src, sreg->umax_value); + meta->umin_dst = min(meta->umin_dst, dreg->umin_value); + meta->umax_dst = max(meta->umax_dst, dreg->umax_value); + + /* NFP supports u16 and u32 multiplication. + * + * For ALU64, if either operand is beyond u32's value range, we reject + * it. One thing to note, if the source operand is BPF_K, then we need + * to check "imm" field directly, and we'd reject it if it is negative. + * Because for ALU64, "imm" (with s32 type) is expected to be sign + * extended to s64 which NFP mul doesn't support. + * + * For ALU32, it is fine for "imm" be negative though, because the + * result is 32-bits and there is no difference on the low halve of + * the result for signed/unsigned mul, so we will get correct result. + */ + if (is_mbpf_mul(meta)) { + if (meta->umax_dst > U32_MAX) { + pr_vlog(env, "multiplier is not within u32 value range\n"); + return -EINVAL; + } + if (mbpf_src(meta) == BPF_X && meta->umax_src > U32_MAX) { + pr_vlog(env, "multiplicand is not within u32 value range\n"); + return -EINVAL; + } + if (mbpf_class(meta) == BPF_ALU64 && + mbpf_src(meta) == BPF_K && meta->insn.imm < 0) { + pr_vlog(env, "sign extended multiplicand won't be within u32 value range\n"); + return -EINVAL; + } + } + + /* NFP doesn't have divide instructions, we support divide by constant + * through reciprocal multiplication. Given NFP support multiplication + * no bigger than u32, we'd require divisor and dividend no bigger than + * that as well. + * + * Also eBPF doesn't support signed divide and has enforced this on C + * language level by failing compilation. However LLVM assembler hasn't + * enforced this, so it is possible for negative constant to leak in as + * a BPF_K operand through assembly code, we reject such cases as well. + */ + if (is_mbpf_div(meta)) { + if (meta->umax_dst > U32_MAX) { + pr_vlog(env, "dividend is not within u32 value range\n"); + return -EINVAL; + } + if (mbpf_src(meta) == BPF_X) { + if (meta->umin_src != meta->umax_src) { + pr_vlog(env, "divisor is not constant\n"); + return -EINVAL; + } + if (meta->umax_src > U32_MAX) { + pr_vlog(env, "divisor is not within u32 value range\n"); + return -EINVAL; + } + } + if (mbpf_src(meta) == BPF_K && meta->insn.imm < 0) { + pr_vlog(env, "divide by negative constant is not supported\n"); + return -EINVAL; + } + } + + return 0; +} + static int nfp_verify_insn(struct bpf_verifier_env *env, int insn_idx, int prev_insn_idx) { @@ -551,13 +627,8 @@ nfp_verify_insn(struct bpf_verifier_env *env, int insn_idx, int prev_insn_idx) if (is_mbpf_xadd(meta)) return nfp_bpf_check_xadd(nfp_prog, meta, env); - if (is_mbpf_indir_shift(meta)) { - const struct bpf_reg_state *sreg = - cur_regs(env) + meta->insn.src_reg; - - meta->umin = min(meta->umin, sreg->umin_value); - meta->umax = max(meta->umax, sreg->umax_value); - } + if (is_mbpf_alu(meta)) + return nfp_bpf_check_alu(nfp_prog, meta, env); return 0; } diff --git a/drivers/net/ethernet/netronome/nfp/nfp_asm.h b/drivers/net/ethernet/netronome/nfp/nfp_asm.h index f6677bc9875a..cdc4e065f6f5 100644 --- a/drivers/net/ethernet/netronome/nfp/nfp_asm.h +++ b/drivers/net/ethernet/netronome/nfp/nfp_asm.h @@ -426,4 +426,32 @@ static inline u32 nfp_get_ind_csr_ctx_ptr_offs(u32 read_offset) return (read_offset & ~NFP_IND_ME_CTX_PTR_BASE_MASK) | NFP_CSR_CTX_PTR; } +enum mul_type { + MUL_TYPE_START = 0x00, + MUL_TYPE_STEP_24x8 = 0x01, + MUL_TYPE_STEP_16x16 = 0x02, + MUL_TYPE_STEP_32x32 = 0x03, +}; + +enum mul_step { + MUL_STEP_1 = 0x00, + MUL_STEP_NONE = MUL_STEP_1, + MUL_STEP_2 = 0x01, + MUL_STEP_3 = 0x02, + MUL_STEP_4 = 0x03, + MUL_LAST = 0x04, + MUL_LAST_2 = 0x05, +}; + +#define OP_MUL_BASE 0x0f800000000ULL +#define OP_MUL_A_SRC 0x000000003ffULL +#define OP_MUL_B_SRC 0x000000ffc00ULL +#define OP_MUL_STEP 0x00000700000ULL +#define OP_MUL_DST_AB 0x00000800000ULL +#define OP_MUL_SW 0x00040000000ULL +#define OP_MUL_TYPE 0x00180000000ULL +#define OP_MUL_WR_AB 0x20000000000ULL +#define OP_MUL_SRC_LMEXTN 0x40000000000ULL +#define OP_MUL_DST_LMEXTN 0x80000000000ULL + #endif diff --git a/include/linux/reciprocal_div.h b/include/linux/reciprocal_div.h index e031e9f2f9d8..585ce89c0f33 100644 --- a/include/linux/reciprocal_div.h +++ b/include/linux/reciprocal_div.h @@ -25,6 +25,9 @@ struct reciprocal_value { u8 sh1, sh2; }; +/* "reciprocal_value" and "reciprocal_divide" together implement the basic + * version of the algorithm described in Figure 4.1 of the paper. + */ struct reciprocal_value reciprocal_value(u32 d); static inline u32 reciprocal_divide(u32 a, struct reciprocal_value R) @@ -33,4 +36,69 @@ static inline u32 reciprocal_divide(u32 a, struct reciprocal_value R) return (t + ((a - t) >> R.sh1)) >> R.sh2; } +struct reciprocal_value_adv { + u32 m; + u8 sh, exp; + bool is_wide_m; +}; + +/* "reciprocal_value_adv" implements the advanced version of the algorithm + * described in Figure 4.2 of the paper except when "divisor > (1U << 31)" whose + * ceil(log2(d)) result will be 32 which then requires u128 divide on host. The + * exception case could be easily handled before calling "reciprocal_value_adv". + * + * The advanced version requires more complex calculation to get the reciprocal + * multiplier and other control variables, but then could reduce the required + * emulation operations. + * + * It makes no sense to use this advanced version for host divide emulation, + * those extra complexities for calculating multiplier etc could completely + * waive our saving on emulation operations. + * + * However, it makes sense to use it for JIT divide code generation for which + * we are willing to trade performance of JITed code with that of host. As shown + * by the following pseudo code, the required emulation operations could go down + * from 6 (the basic version) to 3 or 4. + * + * To use the result of "reciprocal_value_adv", suppose we want to calculate + * n/d, the pseudo C code will be: + * + * struct reciprocal_value_adv rvalue; + * u8 pre_shift, exp; + * + * // handle exception case. + * if (d >= (1U << 31)) { + * result = n >= d; + * return; + * } + * + * rvalue = reciprocal_value_adv(d, 32) + * exp = rvalue.exp; + * if (rvalue.is_wide_m && !(d & 1)) { + * // floor(log2(d & (2^32 -d))) + * pre_shift = fls(d & -d) - 1; + * rvalue = reciprocal_value_adv(d >> pre_shift, 32 - pre_shift); + * } else { + * pre_shift = 0; + * } + * + * // code generation starts. + * if (imm == 1U << exp) { + * result = n >> exp; + * } else if (rvalue.is_wide_m) { + * // pre_shift must be zero when reached here. + * t = (n * rvalue.m) >> 32; + * result = n - t; + * result >>= 1; + * result += t; + * result >>= rvalue.sh - 1; + * } else { + * if (pre_shift) + * result = n >> pre_shift; + * result = ((u64)result * rvalue.m) >> 32; + * result >>= rvalue.sh; + * } + */ +struct reciprocal_value_adv reciprocal_value_adv(u32 d, u8 prec); + #endif /* _LINUX_RECIPROCAL_DIV_H */ diff --git a/lib/reciprocal_div.c b/lib/reciprocal_div.c index fcb4ce682c6f..bf043258fa00 100644 --- a/lib/reciprocal_div.c +++ b/lib/reciprocal_div.c @@ -1,4 +1,5 @@ // SPDX-License-Identifier: GPL-2.0 +#include #include #include #include @@ -26,3 +27,43 @@ struct reciprocal_value reciprocal_value(u32 d) return R; } EXPORT_SYMBOL(reciprocal_value); + +struct reciprocal_value_adv reciprocal_value_adv(u32 d, u8 prec) +{ + struct reciprocal_value_adv R; + u32 l, post_shift; + u64 mhigh, mlow; + + /* ceil(log2(d)) */ + l = fls(d - 1); + /* NOTE: mlow/mhigh could overflow u64 when l == 32. This case needs to + * be handled before calling "reciprocal_value_adv", please see the + * comment at include/linux/reciprocal_div.h. + */ + WARN(l == 32, + "ceil(log2(0x%08x)) == 32, %s doesn't support such divisor", + d, __func__); + post_shift = l; + mlow = 1ULL << (32 + l); + do_div(mlow, d); + mhigh = (1ULL << (32 + l)) + (1ULL << (32 + l - prec)); + do_div(mhigh, d); + + for (; post_shift > 0; post_shift--) { + u64 lo = mlow >> 1, hi = mhigh >> 1; + + if (lo >= hi) + break; + + mlow = lo; + mhigh = hi; + } + + R.m = (u32)mhigh; + R.sh = post_shift; + R.exp = l; + R.is_wide_m = mhigh > U32_MAX; + + return R; +} +EXPORT_SYMBOL(reciprocal_value_adv);