diff --git a/drivers/vfio/group.c b/drivers/vfio/group.c index 5c17ad812313..610a429c6191 100644 --- a/drivers/vfio/group.c +++ b/drivers/vfio/group.c @@ -682,16 +682,6 @@ static struct vfio_group *vfio_group_find_or_alloc(struct device *dev) if (!iommu_group) return ERR_PTR(-EINVAL); - /* - * VFIO always sets IOMMU_CACHE because we offer no way for userspace to - * restore cache coherency. It has to be checked here because it is only - * valid for cases where we are using iommu groups. - */ - if (!device_iommu_capable(dev, IOMMU_CAP_CACHE_COHERENCY)) { - iommu_group_put(iommu_group); - return ERR_PTR(-EINVAL); - } - mutex_lock(&vfio.group_lock); group = vfio_group_find_from_iommu(iommu_group); if (group) { diff --git a/drivers/vfio/vfio_main.c b/drivers/vfio/vfio_main.c index ba1d84afe081..902f06e52c48 100644 --- a/drivers/vfio/vfio_main.c +++ b/drivers/vfio/vfio_main.c @@ -292,6 +292,17 @@ static int __vfio_register_dev(struct vfio_device *device, if (ret) return ret; + /* + * VFIO always sets IOMMU_CACHE because we offer no way for userspace to + * restore cache coherency. It has to be checked here because it is only + * valid for cases where we are using iommu groups. + */ + if (type == VFIO_IOMMU && !vfio_device_is_noiommu(device) && + !device_iommu_capable(device->dev, IOMMU_CAP_CACHE_COHERENCY)) { + ret = -EINVAL; + goto err_out; + } + ret = vfio_device_add(device); if (ret) goto err_out;