diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c index 88362c0afe45..67bd947cc556 100644 --- a/drivers/vhost/vhost.c +++ b/drivers/vhost/vhost.c @@ -664,16 +664,22 @@ static void __vhost_vq_attach_worker(struct vhost_virtqueue *vq, { struct vhost_worker *old_worker; - old_worker = rcu_dereference_check(vq->worker, - lockdep_is_held(&vq->dev->mutex)); - mutex_lock(&worker->mutex); - worker->attachment_cnt++; - mutex_unlock(&worker->mutex); - rcu_assign_pointer(vq->worker, worker); + mutex_lock(&vq->mutex); - if (!old_worker) + old_worker = rcu_dereference_check(vq->worker, + lockdep_is_held(&vq->mutex)); + rcu_assign_pointer(vq->worker, worker); + worker->attachment_cnt++; + + if (!old_worker) { + mutex_unlock(&vq->mutex); + mutex_unlock(&worker->mutex); return; + } + mutex_unlock(&vq->mutex); + mutex_unlock(&worker->mutex); + /* * Take the worker mutex to make sure we see the work queued from * device wide flushes which doesn't use RCU for execution.