io_uring/tctx: have io_uring_alloc_task_context() return tctx

Instead of having io_uring_alloc_task_context() return an int and
assign tsk->io_uring, just have it return the task context directly.
This enables cleaner error handling in callers, which may have
failure points post calling io_uring_alloc_task_context().

Signed-off-by: Jens Axboe <axboe@kernel.dk>
master
Jens Axboe 2026-04-08 11:31:05 -06:00
parent f847bf6d29
commit 2c453a4281
3 changed files with 19 additions and 14 deletions

View File

@ -458,6 +458,7 @@ __cold int io_sq_offload_create(struct io_ring_ctx *ctx,
return -EINVAL; return -EINVAL;
} }
if (ctx->flags & IORING_SETUP_SQPOLL) { if (ctx->flags & IORING_SETUP_SQPOLL) {
struct io_uring_task *tctx;
struct task_struct *tsk; struct task_struct *tsk;
struct io_sq_data *sqd; struct io_sq_data *sqd;
bool attached; bool attached;
@ -524,8 +525,13 @@ __cold int io_sq_offload_create(struct io_ring_ctx *ctx,
rcu_assign_pointer(sqd->thread, tsk); rcu_assign_pointer(sqd->thread, tsk);
mutex_unlock(&sqd->lock); mutex_unlock(&sqd->lock);
ret = 0;
get_task_struct(tsk); get_task_struct(tsk);
ret = io_uring_alloc_task_context(tsk, ctx); tctx = io_uring_alloc_task_context(tsk, ctx);
if (!IS_ERR(tctx))
tsk->io_uring = tctx;
else
ret = PTR_ERR(tctx);
wake_up_new_task(tsk); wake_up_new_task(tsk);
if (ret) if (ret)
goto err; goto err;

View File

@ -74,7 +74,7 @@ void __io_uring_free(struct task_struct *tsk)
} }
} }
__cold int io_uring_alloc_task_context(struct task_struct *task, __cold struct io_uring_task *io_uring_alloc_task_context(struct task_struct *task,
struct io_ring_ctx *ctx) struct io_ring_ctx *ctx)
{ {
struct io_uring_task *tctx; struct io_uring_task *tctx;
@ -82,12 +82,12 @@ __cold int io_uring_alloc_task_context(struct task_struct *task,
tctx = kzalloc_obj(*tctx); tctx = kzalloc_obj(*tctx);
if (unlikely(!tctx)) if (unlikely(!tctx))
return -ENOMEM; return ERR_PTR(-ENOMEM);
ret = percpu_counter_init(&tctx->inflight, 0, GFP_KERNEL); ret = percpu_counter_init(&tctx->inflight, 0, GFP_KERNEL);
if (unlikely(ret)) { if (unlikely(ret)) {
kfree(tctx); kfree(tctx);
return ret; return ERR_PTR(ret);
} }
tctx->io_wq = io_init_wq_offload(ctx, task); tctx->io_wq = io_init_wq_offload(ctx, task);
@ -95,7 +95,7 @@ __cold int io_uring_alloc_task_context(struct task_struct *task,
ret = PTR_ERR(tctx->io_wq); ret = PTR_ERR(tctx->io_wq);
percpu_counter_destroy(&tctx->inflight); percpu_counter_destroy(&tctx->inflight);
kfree(tctx); kfree(tctx);
return ret; return ERR_PTR(ret);
} }
tctx->task = task; tctx->task = task;
@ -103,10 +103,9 @@ __cold int io_uring_alloc_task_context(struct task_struct *task,
init_waitqueue_head(&tctx->wait); init_waitqueue_head(&tctx->wait);
atomic_set(&tctx->in_cancel, 0); atomic_set(&tctx->in_cancel, 0);
atomic_set(&tctx->inflight_tracked, 0); atomic_set(&tctx->inflight_tracked, 0);
task->io_uring = tctx;
init_llist_head(&tctx->task_list); init_llist_head(&tctx->task_list);
init_task_work(&tctx->task_work, tctx_task_work); init_task_work(&tctx->task_work, tctx_task_work);
return 0; return tctx;
} }
int __io_uring_add_tctx_node(struct io_ring_ctx *ctx) int __io_uring_add_tctx_node(struct io_ring_ctx *ctx)
@ -116,11 +115,11 @@ int __io_uring_add_tctx_node(struct io_ring_ctx *ctx)
int ret; int ret;
if (unlikely(!tctx)) { if (unlikely(!tctx)) {
ret = io_uring_alloc_task_context(current, ctx); tctx = io_uring_alloc_task_context(current, ctx);
if (unlikely(ret)) if (IS_ERR(tctx))
return ret; return PTR_ERR(tctx);
tctx = current->io_uring; current->io_uring = tctx;
if (ctx->int_flags & IO_RING_F_IOWQ_LIMITS_SET) { if (ctx->int_flags & IO_RING_F_IOWQ_LIMITS_SET) {
unsigned int limits[2] = { ctx->iowq_limits[0], unsigned int limits[2] = { ctx->iowq_limits[0],
ctx->iowq_limits[1], }; ctx->iowq_limits[1], };

View File

@ -6,7 +6,7 @@ struct io_tctx_node {
struct io_ring_ctx *ctx; struct io_ring_ctx *ctx;
}; };
int io_uring_alloc_task_context(struct task_struct *task, struct io_uring_task *io_uring_alloc_task_context(struct task_struct *task,
struct io_ring_ctx *ctx); struct io_ring_ctx *ctx);
void io_uring_del_tctx_node(unsigned long index); void io_uring_del_tctx_node(unsigned long index);
int __io_uring_add_tctx_node(struct io_ring_ctx *ctx); int __io_uring_add_tctx_node(struct io_ring_ctx *ctx);