diff --git a/io_uring/io_uring.c b/io_uring/io_uring.c index 6eaa21e09469..34104c256c88 100644 --- a/io_uring/io_uring.c +++ b/io_uring/io_uring.c @@ -2308,6 +2308,10 @@ static __cold void io_ring_exit_work(struct work_struct *work) struct io_tctx_node *node; int ret; + mutex_lock(&ctx->uring_lock); + io_terminate_zcrx(ctx); + mutex_unlock(&ctx->uring_lock); + /* * If we're doing polled IO and end up having requests being * submitted async (out-of-line), then completions can come in while diff --git a/io_uring/zcrx.c b/io_uring/zcrx.c index 73fa82759771..615805d2c3dd 100644 --- a/io_uring/zcrx.c +++ b/io_uring/zcrx.c @@ -624,12 +624,17 @@ static void io_zcrx_scrub(struct io_zcrx_ifq *ifq) } } -static void zcrx_unregister(struct io_zcrx_ifq *ifq) +static void zcrx_unregister_user(struct io_zcrx_ifq *ifq) { if (refcount_dec_and_test(&ifq->user_refs)) { io_close_queue(ifq); io_zcrx_scrub(ifq); } +} + +static void zcrx_unregister(struct io_zcrx_ifq *ifq) +{ + zcrx_unregister_user(ifq); io_put_zcrx_ifq(ifq); } @@ -887,6 +892,36 @@ static struct net_iov *__io_zcrx_get_free_niov(struct io_zcrx_area *area) return &area->nia.niovs[niov_idx]; } +static inline bool is_zcrx_entry_marked(struct io_ring_ctx *ctx, unsigned long id) +{ + return xa_get_mark(&ctx->zcrx_ctxs, id, XA_MARK_0); +} + +static inline void set_zcrx_entry_mark(struct io_ring_ctx *ctx, unsigned long id) +{ + xa_set_mark(&ctx->zcrx_ctxs, id, XA_MARK_0); +} + +void io_terminate_zcrx(struct io_ring_ctx *ctx) +{ + struct io_zcrx_ifq *ifq; + unsigned long id = 0; + + lockdep_assert_held(&ctx->uring_lock); + + while (1) { + scoped_guard(mutex, &ctx->mmap_lock) + ifq = xa_find(&ctx->zcrx_ctxs, &id, ULONG_MAX, XA_PRESENT); + if (!ifq) + break; + if (WARN_ON_ONCE(is_zcrx_entry_marked(ctx, id))) + break; + set_zcrx_entry_mark(ctx, id); + id++; + zcrx_unregister_user(ifq); + } +} + void io_unregister_zcrx_ifqs(struct io_ring_ctx *ctx) { struct io_zcrx_ifq *ifq; @@ -898,12 +933,17 @@ void io_unregister_zcrx_ifqs(struct io_ring_ctx *ctx) unsigned long id = 0; ifq = xa_find(&ctx->zcrx_ctxs, &id, ULONG_MAX, XA_PRESENT); - if (ifq) + if (ifq) { + if (WARN_ON_ONCE(!is_zcrx_entry_marked(ctx, id))) { + ifq = NULL; + break; + } xa_erase(&ctx->zcrx_ctxs, id); + } } if (!ifq) break; - zcrx_unregister(ifq); + io_put_zcrx_ifq(ifq); } xa_destroy(&ctx->zcrx_ctxs); diff --git a/io_uring/zcrx.h b/io_uring/zcrx.h index 0ddcf0ee8861..0316a41a3561 100644 --- a/io_uring/zcrx.h +++ b/io_uring/zcrx.h @@ -74,6 +74,7 @@ int io_zcrx_ctrl(struct io_ring_ctx *ctx, void __user *arg, unsigned nr_arg); int io_register_zcrx_ifq(struct io_ring_ctx *ctx, struct io_uring_zcrx_ifq_reg __user *arg); void io_unregister_zcrx_ifqs(struct io_ring_ctx *ctx); +void io_terminate_zcrx(struct io_ring_ctx *ctx); int io_zcrx_recv(struct io_kiocb *req, struct io_zcrx_ifq *ifq, struct socket *sock, unsigned int flags, unsigned issue_flags, unsigned int *len); @@ -88,6 +89,9 @@ static inline int io_register_zcrx_ifq(struct io_ring_ctx *ctx, static inline void io_unregister_zcrx_ifqs(struct io_ring_ctx *ctx) { } +static inline void io_terminate_zcrx(struct io_ring_ctx *ctx) +{ +} static inline int io_zcrx_recv(struct io_kiocb *req, struct io_zcrx_ifq *ifq, struct socket *sock, unsigned int flags, unsigned issue_flags, unsigned int *len)