diff --git a/drivers/misc/habanalabs/common/command_submission.c b/drivers/misc/habanalabs/common/command_submission.c index f7fac82ac41d..3affb350070c 100644 --- a/drivers/misc/habanalabs/common/command_submission.c +++ b/drivers/misc/habanalabs/common/command_submission.c @@ -479,6 +479,9 @@ static int allocate_cs(struct hl_device *hdev, struct hl_ctx *ctx, return -ENOMEM; } + /* increment refcnt for context */ + hl_ctx_get(hdev, ctx); + cs->ctx = ctx; cs->submitted = false; cs->completed = false; @@ -550,6 +553,7 @@ free_fence: kfree(cs_cmpl); free_cs: kfree(cs); + hl_ctx_put(ctx); return rc; } @@ -827,14 +831,9 @@ static int cs_ioctl_default(struct hl_fpriv *hpriv, void __user *chunks, if (rc) goto out; - /* increment refcnt for context */ - hl_ctx_get(hdev, hpriv->ctx); - rc = allocate_cs(hdev, hpriv->ctx, CS_TYPE_DEFAULT, &cs); - if (rc) { - hl_ctx_put(hpriv->ctx); + if (rc) goto free_cs_chunk_array; - } cs->timestamp = !!(flags & HL_CS_FLAGS_TIMESTAMP); *cs_seq = cs->sequence; @@ -1276,15 +1275,11 @@ static int cs_ioctl_signal_wait(struct hl_fpriv *hpriv, enum hl_cs_type cs_type, } } - /* increment refcnt for context */ - hl_ctx_get(hdev, ctx); - rc = allocate_cs(hdev, ctx, cs_type, &cs); if (rc) { if (cs_type == CS_TYPE_WAIT || cs_type == CS_TYPE_COLLECTIVE_WAIT) hl_fence_put(sig_fence); - hl_ctx_put(ctx); goto free_cs_chunk_array; }