diff --git a/drivers/cxl/core/memdev.c b/drivers/cxl/core/memdev.c index 63353d990374..4c2e24a1a89c 100644 --- a/drivers/cxl/core/memdev.c +++ b/drivers/cxl/core/memdev.c @@ -556,20 +556,11 @@ void clear_exclusive_cxl_commands(struct cxl_memdev_state *mds, } EXPORT_SYMBOL_NS_GPL(clear_exclusive_cxl_commands, CXL); -static void cxl_memdev_security_shutdown(struct device *dev) -{ - struct cxl_memdev *cxlmd = to_cxl_memdev(dev); - struct cxl_memdev_state *mds = to_cxl_memdev_state(cxlmd->cxlds); - - cancel_delayed_work_sync(&mds->security.poll_dwork); -} - static void cxl_memdev_shutdown(struct device *dev) { struct cxl_memdev *cxlmd = to_cxl_memdev(dev); down_write(&cxl_memdev_rwsem); - cxl_memdev_security_shutdown(dev); cxlmd->cxlds = NULL; up_write(&cxl_memdev_rwsem); } @@ -991,35 +982,6 @@ static const struct file_operations cxl_memdev_fops = { .llseek = noop_llseek, }; -static void put_sanitize(void *data) -{ - struct cxl_memdev_state *mds = data; - - sysfs_put(mds->security.sanitize_node); -} - -static int cxl_memdev_security_init(struct cxl_memdev *cxlmd) -{ - struct cxl_dev_state *cxlds = cxlmd->cxlds; - struct cxl_memdev_state *mds = to_cxl_memdev_state(cxlds); - struct device *dev = &cxlmd->dev; - struct kernfs_node *sec; - - sec = sysfs_get_dirent(dev->kobj.sd, "security"); - if (!sec) { - dev_err(dev, "sysfs_get_dirent 'security' failed\n"); - return -ENODEV; - } - mds->security.sanitize_node = sysfs_get_dirent(sec, "state"); - sysfs_put(sec); - if (!mds->security.sanitize_node) { - dev_err(dev, "sysfs_get_dirent 'state' failed\n"); - return -ENODEV; - } - - return devm_add_action_or_reset(cxlds->dev, put_sanitize, mds); -} - struct cxl_memdev *devm_cxl_add_memdev(struct device *host, struct cxl_dev_state *cxlds) { @@ -1049,10 +1011,6 @@ struct cxl_memdev *devm_cxl_add_memdev(struct device *host, if (rc) goto err; - rc = cxl_memdev_security_init(cxlmd); - if (rc) - goto err; - rc = devm_add_action_or_reset(host, cxl_memdev_unregister, cxlmd); if (rc) return ERR_PTR(rc); @@ -1069,6 +1027,50 @@ err: } EXPORT_SYMBOL_NS_GPL(devm_cxl_add_memdev, CXL); +static void sanitize_teardown_notifier(void *data) +{ + struct cxl_memdev_state *mds = data; + struct kernfs_node *state; + + /* + * Prevent new irq triggered invocations of the workqueue and + * flush inflight invocations. + */ + mutex_lock(&mds->mbox_mutex); + state = mds->security.sanitize_node; + mds->security.sanitize_node = NULL; + mutex_unlock(&mds->mbox_mutex); + + cancel_delayed_work_sync(&mds->security.poll_dwork); + sysfs_put(state); +} + +int devm_cxl_sanitize_setup_notifier(struct device *host, + struct cxl_memdev *cxlmd) +{ + struct cxl_dev_state *cxlds = cxlmd->cxlds; + struct cxl_memdev_state *mds = to_cxl_memdev_state(cxlds); + struct kernfs_node *sec; + + if (!test_bit(CXL_SEC_ENABLED_SANITIZE, mds->security.enabled_cmds)) + return 0; + + /* + * Note, the expectation is that @cxlmd would have failed to be + * created if these sysfs_get_dirent calls fail. + */ + sec = sysfs_get_dirent(cxlmd->dev.kobj.sd, "security"); + if (!sec) + return -ENOENT; + mds->security.sanitize_node = sysfs_get_dirent(sec, "state"); + sysfs_put(sec); + if (!mds->security.sanitize_node) + return -ENOENT; + + return devm_add_action_or_reset(host, sanitize_teardown_notifier, mds); +} +EXPORT_SYMBOL_NS_GPL(devm_cxl_sanitize_setup_notifier, CXL); + __init int cxl_memdev_init(void) { dev_t devt; diff --git a/drivers/cxl/cxlmem.h b/drivers/cxl/cxlmem.h index fdb2c8dd98d0..fbdee1d63717 100644 --- a/drivers/cxl/cxlmem.h +++ b/drivers/cxl/cxlmem.h @@ -86,6 +86,8 @@ static inline bool is_cxl_endpoint(struct cxl_port *port) struct cxl_memdev *devm_cxl_add_memdev(struct device *host, struct cxl_dev_state *cxlds); +int devm_cxl_sanitize_setup_notifier(struct device *host, + struct cxl_memdev *cxlmd); struct cxl_memdev_state; int devm_cxl_setup_fw_upload(struct device *host, struct cxl_memdev_state *mds); int devm_cxl_dpa_reserve(struct cxl_endpoint_decoder *cxled, diff --git a/drivers/cxl/pci.c b/drivers/cxl/pci.c index 58894b8b59da..05f3d14921e6 100644 --- a/drivers/cxl/pci.c +++ b/drivers/cxl/pci.c @@ -875,6 +875,10 @@ static int cxl_pci_probe(struct pci_dev *pdev, const struct pci_device_id *id) if (rc) return rc; + rc = devm_cxl_sanitize_setup_notifier(&pdev->dev, cxlmd); + if (rc) + return rc; + pmu_count = cxl_count_regblock(pdev, CXL_REGLOC_RBI_PMU); for (i = 0; i < pmu_count; i++) { struct cxl_pmu_regs pmu_regs;