diff --git a/drivers/infiniband/core/uverbs_cmd.c b/drivers/infiniband/core/uverbs_cmd.c index c398d1a64614..d413dafb9211 100644 --- a/drivers/infiniband/core/uverbs_cmd.c +++ b/drivers/infiniband/core/uverbs_cmd.c @@ -3031,12 +3031,29 @@ static int ib_uverbs_ex_modify_wq(struct uverbs_attr_bundle *attrs) if (!wq) return -EINVAL; - wq_attr.curr_wq_state = cmd.curr_wq_state; - wq_attr.wq_state = cmd.wq_state; if (cmd.attr_mask & IB_WQ_FLAGS) { wq_attr.flags = cmd.flags; wq_attr.flags_mask = cmd.flags_mask; } + + if (cmd.attr_mask & IB_WQ_CUR_STATE) { + if (cmd.curr_wq_state > IB_WQS_ERR) + return -EINVAL; + + wq_attr.curr_wq_state = cmd.curr_wq_state; + } else { + wq_attr.curr_wq_state = wq->state; + } + + if (cmd.attr_mask & IB_WQ_STATE) { + if (cmd.wq_state > IB_WQS_ERR) + return -EINVAL; + + wq_attr.wq_state = cmd.wq_state; + } else { + wq_attr.wq_state = wq_attr.curr_wq_state; + } + ret = wq->device->ops.modify_wq(wq, &wq_attr, cmd.attr_mask, &attrs->driver_udata); uobj_put_obj_read(wq); diff --git a/drivers/infiniband/hw/mlx4/qp.c b/drivers/infiniband/hw/mlx4/qp.c index 6e2b3e2f83f1..17ce928e41bd 100644 --- a/drivers/infiniband/hw/mlx4/qp.c +++ b/drivers/infiniband/hw/mlx4/qp.c @@ -4294,13 +4294,8 @@ int mlx4_ib_modify_wq(struct ib_wq *ibwq, struct ib_wq_attr *wq_attr, if (wq_attr_mask & IB_WQ_FLAGS) return -EOPNOTSUPP; - cur_state = wq_attr_mask & IB_WQ_CUR_STATE ? wq_attr->curr_wq_state : - ibwq->state; - new_state = wq_attr_mask & IB_WQ_STATE ? wq_attr->wq_state : cur_state; - - if (cur_state < IB_WQS_RESET || cur_state > IB_WQS_ERR || - new_state < IB_WQS_RESET || new_state > IB_WQS_ERR) - return -EINVAL; + cur_state = wq_attr->curr_wq_state; + new_state = wq_attr->wq_state; if ((new_state == IB_WQS_RDY) && (cur_state == IB_WQS_ERR)) return -EINVAL; diff --git a/drivers/infiniband/hw/mlx5/qp.c b/drivers/infiniband/hw/mlx5/qp.c index 09e29c6cb66d..4540835e05bd 100644 --- a/drivers/infiniband/hw/mlx5/qp.c +++ b/drivers/infiniband/hw/mlx5/qp.c @@ -6317,10 +6317,8 @@ int mlx5_ib_modify_wq(struct ib_wq *wq, struct ib_wq_attr *wq_attr, rqc = MLX5_ADDR_OF(modify_rq_in, in, ctx); - curr_wq_state = (wq_attr_mask & IB_WQ_CUR_STATE) ? - wq_attr->curr_wq_state : wq->state; - wq_state = (wq_attr_mask & IB_WQ_STATE) ? - wq_attr->wq_state : curr_wq_state; + curr_wq_state = wq_attr->curr_wq_state; + wq_state = wq_attr->wq_state; if (curr_wq_state == IB_WQS_ERR) curr_wq_state = MLX5_RQC_STATE_ERR; if (wq_state == IB_WQS_ERR)