diff --git a/fs/io_uring.c b/fs/io_uring.c index 6ad6c56e7f4d..dae87c576943 100644 --- a/fs/io_uring.c +++ b/fs/io_uring.c @@ -710,6 +710,7 @@ enum { REQ_F_REISSUE_BIT, REQ_F_DONT_REISSUE_BIT, REQ_F_CREDS_BIT, + REQ_F_REFCOUNT_BIT, /* keep async read/write and isreg together and in order */ REQ_F_NOWAIT_READ_BIT, REQ_F_NOWAIT_WRITE_BIT, @@ -765,6 +766,8 @@ enum { REQ_F_ISREG = BIT(REQ_F_ISREG_BIT), /* has creds assigned */ REQ_F_CREDS = BIT(REQ_F_CREDS_BIT), + /* skip refcounting if not set */ + REQ_F_REFCOUNT = BIT(REQ_F_REFCOUNT_BIT), }; struct async_poll { @@ -1087,26 +1090,40 @@ EXPORT_SYMBOL(io_uring_get_socket); static inline bool req_ref_inc_not_zero(struct io_kiocb *req) { + WARN_ON_ONCE(!(req->flags & REQ_F_REFCOUNT)); return atomic_inc_not_zero(&req->refs); } static inline bool req_ref_put_and_test(struct io_kiocb *req) { + if (likely(!(req->flags & REQ_F_REFCOUNT))) + return true; + WARN_ON_ONCE(req_ref_zero_or_close_to_overflow(req)); return atomic_dec_and_test(&req->refs); } static inline void req_ref_put(struct io_kiocb *req) { + WARN_ON_ONCE(!(req->flags & REQ_F_REFCOUNT)); WARN_ON_ONCE(req_ref_put_and_test(req)); } static inline void req_ref_get(struct io_kiocb *req) { + WARN_ON_ONCE(!(req->flags & REQ_F_REFCOUNT)); WARN_ON_ONCE(req_ref_zero_or_close_to_overflow(req)); atomic_inc(&req->refs); } +static inline void io_req_refcount(struct io_kiocb *req) +{ + if (!(req->flags & REQ_F_REFCOUNT)) { + req->flags |= REQ_F_REFCOUNT; + atomic_set(&req->refs, 1); + } +} + static inline void io_req_set_rsrc_node(struct io_kiocb *req) { struct io_ring_ctx *ctx = req->ctx; @@ -5192,6 +5209,7 @@ static int io_arm_poll_handler(struct io_kiocb *req) req->apoll = apoll; req->flags |= REQ_F_POLLED; ipt.pt._qproc = io_async_queue_proc; + io_req_refcount(req); ret = __io_arm_poll_handler(req, &apoll->poll, &ipt, mask, io_async_wake); @@ -5382,6 +5400,7 @@ static int io_poll_add_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe if (flags & ~IORING_POLL_ADD_MULTI) return -EINVAL; + io_req_refcount(req); poll->events = io_poll_parse_events(sqe, flags); return 0; } @@ -6273,6 +6292,7 @@ static void io_wq_submit_work(struct io_wq_work *work) struct io_kiocb *timeout; int ret = 0; + io_req_refcount(req); /* will be dropped by ->io_free_work() after returning to io-wq */ req_ref_get(req); @@ -6442,7 +6462,10 @@ static struct io_kiocb *io_prep_linked_timeout(struct io_kiocb *req) return NULL; /* linked timeouts should have two refs once prep'ed */ + io_req_refcount(req); + io_req_refcount(nxt); req_ref_get(nxt); + nxt->timeout.head = req; nxt->flags |= REQ_F_LTIMEOUT_ACTIVE; req->flags |= REQ_F_LINK_TIMEOUT; @@ -6549,7 +6572,6 @@ static int io_init_req(struct io_ring_ctx *ctx, struct io_kiocb *req, req->user_data = READ_ONCE(sqe->user_data); req->file = NULL; req->fixed_rsrc_refs = NULL; - atomic_set(&req->refs, 1); req->task = current; /* enforce forwards compatibility on users */