diff --git a/drivers/infiniband/hw/mlx5/qp.c b/drivers/infiniband/hw/mlx5/qp.c
index 00cffbfe6c35..5bdf20c676fe 100644
--- a/drivers/infiniband/hw/mlx5/qp.c
+++ b/drivers/infiniband/hw/mlx5/qp.c
@@ -3071,7 +3071,7 @@ static void set_linv_umr_seg(struct mlx5_wqe_umr_ctrl_seg *umr)
 	umr->flags = MLX5_UMR_INLINE;
 }
 
-static __be64 get_umr_reg_mr_mask(void)
+static __be64 get_umr_reg_mr_mask(int atomic)
 {
 	u64 result;
 
@@ -3084,9 +3084,11 @@ static __be64 get_umr_reg_mr_mask(void)
 		 MLX5_MKEY_MASK_KEY		|
 		 MLX5_MKEY_MASK_RR		|
 		 MLX5_MKEY_MASK_RW		|
-		 MLX5_MKEY_MASK_A		|
 		 MLX5_MKEY_MASK_FREE;
 
+	if (atomic)
+		result |= MLX5_MKEY_MASK_A;
+
 	return cpu_to_be64(result);
 }
 
@@ -3147,7 +3149,7 @@ static __be64 get_umr_update_pd_mask(void)
 }
 
 static void set_reg_umr_segment(struct mlx5_wqe_umr_ctrl_seg *umr,
-				struct ib_send_wr *wr)
+				struct ib_send_wr *wr, int atomic)
 {
 	struct mlx5_umr_wr *umrwr = umr_wr(wr);
 
@@ -3172,7 +3174,7 @@ static void set_reg_umr_segment(struct mlx5_wqe_umr_ctrl_seg *umr,
 		if (wr->send_flags & MLX5_IB_SEND_UMR_UPDATE_PD)
 			umr->mkey_mask |= get_umr_update_pd_mask();
 		if (!umr->mkey_mask)
-			umr->mkey_mask = get_umr_reg_mr_mask();
+			umr->mkey_mask = get_umr_reg_mr_mask(atomic);
 	} else {
 		umr->mkey_mask = get_umr_unreg_mr_mask();
 	}
@@ -4025,7 +4027,7 @@ int mlx5_ib_post_send(struct ib_qp *ibqp, struct ib_send_wr *wr,
 			}
 			qp->sq.wr_data[idx] = MLX5_IB_WR_UMR;
 			ctrl->imm = cpu_to_be32(umr_wr(wr)->mkey);
-			set_reg_umr_segment(seg, wr);
+			set_reg_umr_segment(seg, wr, !!(MLX5_CAP_GEN(mdev, atomic)));
 			seg += sizeof(struct mlx5_wqe_umr_ctrl_seg);
 			size += sizeof(struct mlx5_wqe_umr_ctrl_seg) / 16;
 			if (unlikely((seg == qend)))