diff --git a/drivers/infiniband/hw/hns/hns_roce_hw_v2.c b/drivers/infiniband/hw/hns/hns_roce_hw_v2.c index 7eba9b5bfcb8..d5a63e4c3adf 100644 --- a/drivers/infiniband/hw/hns/hns_roce_hw_v2.c +++ b/drivers/infiniband/hw/hns/hns_roce_hw_v2.c @@ -869,13 +869,32 @@ static void hns_roce_free_srq_wqe(struct hns_roce_srq *srq, u32 wqe_index) spin_unlock(&srq->lock); } -int hns_roce_srqwq_overflow(struct hns_roce_srq *srq, u32 nreq) +static int hns_roce_srqwq_overflow(struct hns_roce_srq *srq) { struct hns_roce_idx_que *idx_que = &srq->idx_que; - unsigned int cur; - cur = idx_que->head - idx_que->tail; - return cur + nreq >= srq->wqe_cnt; + return idx_que->head - idx_que->tail >= srq->wqe_cnt; +} + +static int check_post_srq_valid(struct hns_roce_srq *srq, u32 max_sge, + const struct ib_recv_wr *wr) +{ + struct ib_device *ib_dev = srq->ibsrq.device; + + if (unlikely(wr->num_sge > max_sge)) { + ibdev_err(ib_dev, + "failed to check sge, wr->num_sge = %d, max_sge = %u.\n", + wr->num_sge, max_sge); + return -EINVAL; + } + + if (unlikely(hns_roce_srqwq_overflow(srq))) { + ibdev_err(ib_dev, + "failed to check srqwq status, srqwq is full.\n"); + return -ENOMEM; + } + + return 0; } static int get_srq_wqe_idx(struct hns_roce_srq *srq, u32 *wqe_idx) @@ -892,36 +911,40 @@ static int get_srq_wqe_idx(struct hns_roce_srq *srq, u32 *wqe_idx) return 0; } +static void fill_wqe_idx(struct hns_roce_srq *srq, unsigned int wqe_idx) +{ + struct hns_roce_idx_que *idx_que = &srq->idx_que; + unsigned int head; + __le32 *buf; + + head = idx_que->head & (srq->wqe_cnt - 1); + + buf = get_idx_buf(idx_que, head); + *buf = cpu_to_le32(wqe_idx); + + idx_que->head++; +} + static int hns_roce_v2_post_srq_recv(struct ib_srq *ibsrq, const struct ib_recv_wr *wr, const struct ib_recv_wr **bad_wr) { struct hns_roce_dev *hr_dev = to_hr_dev(ibsrq->device); struct hns_roce_srq *srq = to_hr_srq(ibsrq); - u32 wqe_idx, ind, nreq, max_sge; struct hns_roce_v2_db srq_db; unsigned long flags; - __le32 *srq_idx; int ret = 0; + u32 max_sge; + u32 wqe_idx; void *wqe; + u32 nreq; spin_lock_irqsave(&srq->lock, flags); - ind = srq->idx_que.head & (srq->wqe_cnt - 1); max_sge = srq->max_gs - srq->rsv_sge; - for (nreq = 0; wr; ++nreq, wr = wr->next) { - if (unlikely(wr->num_sge > max_sge)) { - ibdev_err(&hr_dev->ib_dev, - "srq: num_sge = %d, max_sge = %u.\n", - wr->num_sge, max_sge); - ret = -EINVAL; - *bad_wr = wr; - break; - } - - if (unlikely(hns_roce_srqwq_overflow(srq, nreq))) { - ret = -ENOMEM; + ret = check_post_srq_valid(srq, max_sge, wr); + if (ret) { *bad_wr = wr; break; } @@ -934,17 +957,11 @@ static int hns_roce_v2_post_srq_recv(struct ib_srq *ibsrq, wqe = get_srq_wqe_buf(srq, wqe_idx); fill_recv_sge_to_wqe(wr, wqe, max_sge, srq->rsv_sge); - - srq_idx = get_idx_buf(&srq->idx_que, ind); - *srq_idx = cpu_to_le32(wqe_idx); - + fill_wqe_idx(srq, wqe_idx); srq->wrid[wqe_idx] = wr->wr_id; - ind = (ind + 1) & (srq->wqe_cnt - 1); } if (likely(nreq)) { - srq->idx_que.head += nreq; - /* * Make sure that descriptors are written before * doorbell record.