diff --git a/drivers/net/netdevsim/dev.c b/drivers/net/netdevsim/dev.c index b962fc8e1397..738784fda117 100644 --- a/drivers/net/netdevsim/dev.c +++ b/drivers/net/netdevsim/dev.c @@ -1556,14 +1556,18 @@ int nsim_drv_probe(struct nsim_bus_dev *nsim_bus_dev) goto err_devlink_unlock; } - err = nsim_dev_resources_register(devlink); + err = devl_register(devlink); if (err) goto err_vfc_free; + err = nsim_dev_resources_register(devlink); + if (err) + goto err_dl_unregister; + err = devlink_params_register(devlink, nsim_devlink_params, ARRAY_SIZE(nsim_devlink_params)); if (err) - goto err_dl_unregister; + goto err_resource_unregister; nsim_devlink_set_params_init_values(nsim_dev, devlink); err = nsim_dev_dummy_region_init(nsim_dev, devlink); @@ -1607,7 +1611,6 @@ int nsim_drv_probe(struct nsim_bus_dev *nsim_bus_dev) nsim_dev->esw_mode = DEVLINK_ESWITCH_MODE_LEGACY; devlink_set_features(devlink, DEVLINK_F_RELOAD); devl_unlock(devlink); - devlink_register(devlink); return 0; err_hwstats_exit: @@ -1629,8 +1632,10 @@ err_dummy_region_exit: err_params_unregister: devlink_params_unregister(devlink, nsim_devlink_params, ARRAY_SIZE(nsim_devlink_params)); -err_dl_unregister: +err_resource_unregister: devl_resources_unregister(devlink); +err_dl_unregister: + devl_unregister(devlink); err_vfc_free: kfree(nsim_dev->vfconfigs); err_devlink_unlock: @@ -1668,7 +1673,6 @@ void nsim_drv_remove(struct nsim_bus_dev *nsim_bus_dev) struct nsim_dev *nsim_dev = dev_get_drvdata(&nsim_bus_dev->dev); struct devlink *devlink = priv_to_devlink(nsim_dev); - devlink_unregister(devlink); devl_lock(devlink); nsim_dev_reload_destroy(nsim_dev); @@ -1677,6 +1681,7 @@ void nsim_drv_remove(struct nsim_bus_dev *nsim_bus_dev) devlink_params_unregister(devlink, nsim_devlink_params, ARRAY_SIZE(nsim_devlink_params)); devl_resources_unregister(devlink); + devl_unregister(devlink); kfree(nsim_dev->vfconfigs); kfree(nsim_dev->fa_cookie); devl_unlock(devlink); diff --git a/include/net/devlink.h b/include/net/devlink.h index 6a2e4f21779f..425ecef431b7 100644 --- a/include/net/devlink.h +++ b/include/net/devlink.h @@ -1647,6 +1647,8 @@ static inline struct devlink *devlink_alloc(const struct devlink_ops *ops, return devlink_alloc_ns(ops, priv_size, &init_net, dev); } void devlink_set_features(struct devlink *devlink, u64 features); +int devl_register(struct devlink *devlink); +void devl_unregister(struct devlink *devlink); void devlink_register(struct devlink *devlink); void devlink_unregister(struct devlink *devlink); void devlink_free(struct devlink *devlink); diff --git a/net/devlink/core.c b/net/devlink/core.c index 371d6821315d..a31a317626d7 100644 --- a/net/devlink/core.c +++ b/net/devlink/core.c @@ -67,6 +67,15 @@ void devl_unlock(struct devlink *devlink) } EXPORT_SYMBOL_GPL(devl_unlock); +/** + * devlink_try_get() - try to obtain a reference on a devlink instance + * @devlink: instance to reference + * + * Obtain a reference on a devlink instance. A reference on a devlink instance + * only implies that it's safe to take the instance lock. It does not imply + * that the instance is registered, use devl_is_registered() after taking + * the instance lock to check registration status. + */ struct devlink *__must_check devlink_try_get(struct devlink *devlink) { if (refcount_inc_not_zero(&devlink->refcount)) @@ -74,66 +83,35 @@ struct devlink *__must_check devlink_try_get(struct devlink *devlink) return NULL; } -static void __devlink_put_rcu(struct rcu_head *head) -{ - struct devlink *devlink = container_of(head, struct devlink, rcu); - - complete(&devlink->comp); -} - void devlink_put(struct devlink *devlink) { if (refcount_dec_and_test(&devlink->refcount)) - /* Make sure unregister operation that may await the completion - * is unblocked only after all users are after the end of - * RCU grace period. - */ - call_rcu(&devlink->rcu, __devlink_put_rcu); + kfree_rcu(devlink, rcu); } -struct devlink * -devlinks_xa_find_get(struct net *net, unsigned long *indexp, - void * (*xa_find_fn)(struct xarray *, unsigned long *, - unsigned long, xa_mark_t)) +struct devlink *devlinks_xa_find_get(struct net *net, unsigned long *indexp) { - struct devlink *devlink; + struct devlink *devlink = NULL; rcu_read_lock(); retry: - devlink = xa_find_fn(&devlinks, indexp, ULONG_MAX, DEVLINK_REGISTERED); + devlink = xa_find(&devlinks, indexp, ULONG_MAX, DEVLINK_REGISTERED); if (!devlink) goto unlock; - /* In case devlink_unregister() was already called and "unregistering" - * mark was set, do not allow to get a devlink reference here. - * This prevents live-lock of devlink_unregister() wait for completion. - */ - if (xa_get_mark(&devlinks, *indexp, DEVLINK_UNREGISTERING)) - goto retry; - - /* For a possible retry, the xa_find_after() should be always used */ - xa_find_fn = xa_find_after; if (!devlink_try_get(devlink)) - goto retry; + goto next; if (!net_eq(devlink_net(devlink), net)) { devlink_put(devlink); - goto retry; + goto next; } unlock: rcu_read_unlock(); return devlink; -} -struct devlink * -devlinks_xa_find_get_first(struct net *net, unsigned long *indexp) -{ - return devlinks_xa_find_get(net, indexp, xa_find); -} - -struct devlink * -devlinks_xa_find_get_next(struct net *net, unsigned long *indexp) -{ - return devlinks_xa_find_get(net, indexp, xa_find_after); +next: + (*indexp)++; + goto retry; } /** @@ -147,8 +125,6 @@ devlinks_xa_find_get_next(struct net *net, unsigned long *indexp) */ void devlink_set_features(struct devlink *devlink, u64 features) { - ASSERT_DEVLINK_NOT_REGISTERED(devlink); - WARN_ON(features & DEVLINK_F_RELOAD && !devlink_reload_supported(devlink->ops)); devlink->features = features; @@ -156,37 +132,48 @@ void devlink_set_features(struct devlink *devlink, u64 features) EXPORT_SYMBOL_GPL(devlink_set_features); /** - * devlink_register - Register devlink instance - * - * @devlink: devlink + * devl_register - Register devlink instance + * @devlink: devlink */ -void devlink_register(struct devlink *devlink) +int devl_register(struct devlink *devlink) { ASSERT_DEVLINK_NOT_REGISTERED(devlink); - /* Make sure that we are in .probe() routine */ + devl_assert_locked(devlink); xa_set_mark(&devlinks, devlink->index, DEVLINK_REGISTERED); devlink_notify_register(devlink); + + return 0; +} +EXPORT_SYMBOL_GPL(devl_register); + +void devlink_register(struct devlink *devlink) +{ + devl_lock(devlink); + devl_register(devlink); + devl_unlock(devlink); } EXPORT_SYMBOL_GPL(devlink_register); /** - * devlink_unregister - Unregister devlink instance - * - * @devlink: devlink + * devl_unregister - Unregister devlink instance + * @devlink: devlink */ -void devlink_unregister(struct devlink *devlink) +void devl_unregister(struct devlink *devlink) { ASSERT_DEVLINK_REGISTERED(devlink); - /* Make sure that we are in .remove() routine */ - - xa_set_mark(&devlinks, devlink->index, DEVLINK_UNREGISTERING); - devlink_put(devlink); - wait_for_completion(&devlink->comp); + devl_assert_locked(devlink); devlink_notify_unregister(devlink); xa_clear_mark(&devlinks, devlink->index, DEVLINK_REGISTERED); - xa_clear_mark(&devlinks, devlink->index, DEVLINK_UNREGISTERING); +} +EXPORT_SYMBOL_GPL(devl_unregister); + +void devlink_unregister(struct devlink *devlink) +{ + devl_lock(devlink); + devl_unregister(devlink); + devl_unlock(devlink); } EXPORT_SYMBOL_GPL(devlink_unregister); @@ -250,7 +237,6 @@ struct devlink *devlink_alloc_ns(const struct devlink_ops *ops, mutex_init(&devlink->reporters_lock); mutex_init(&devlink->linecards_lock); refcount_set(&devlink->refcount, 1); - init_completion(&devlink->comp); return devlink; @@ -296,7 +282,7 @@ void devlink_free(struct devlink *devlink) xa_erase(&devlinks, devlink->index); - kfree(devlink); + devlink_put(devlink); } EXPORT_SYMBOL_GPL(devlink_free); @@ -312,15 +298,18 @@ static void __net_exit devlink_pernet_pre_exit(struct net *net) */ devlinks_xa_for_each_registered_get(net, index, devlink) { WARN_ON(!(devlink->features & DEVLINK_F_RELOAD)); - mutex_lock(&devlink->lock); - err = devlink_reload(devlink, &init_net, - DEVLINK_RELOAD_ACTION_DRIVER_REINIT, - DEVLINK_RELOAD_LIMIT_UNSPEC, - &actions_performed, NULL); - mutex_unlock(&devlink->lock); + devl_lock(devlink); + err = 0; + if (devl_is_registered(devlink)) + err = devlink_reload(devlink, &init_net, + DEVLINK_RELOAD_ACTION_DRIVER_REINIT, + DEVLINK_RELOAD_LIMIT_UNSPEC, + &actions_performed, NULL); + devl_unlock(devlink); + devlink_put(devlink); + if (err && err != -EOPNOTSUPP) pr_warn("Failed to reload devlink instance into init_net\n"); - devlink_put(devlink); } } diff --git a/net/devlink/devl_internal.h b/net/devlink/devl_internal.h index adf9f6c177db..5d2bbe295659 100644 --- a/net/devlink/devl_internal.h +++ b/net/devlink/devl_internal.h @@ -12,7 +12,6 @@ #include #define DEVLINK_REGISTERED XA_MARK_1 -#define DEVLINK_UNREGISTERING XA_MARK_2 #define DEVLINK_RELOAD_STATS_ARRAY_SIZE \ (__DEVLINK_RELOAD_LIMIT_MAX * __DEVLINK_RELOAD_ACTION_MAX) @@ -52,7 +51,6 @@ struct devlink { struct lock_class_key lock_key; u8 reload_failed:1; refcount_t refcount; - struct completion comp; struct rcu_head rcu; struct notifier_block netdevice_nb; char priv[] __aligned(NETDEV_ALIGN); @@ -82,18 +80,17 @@ extern struct genl_family devlink_nl_family; * in loop body in order to release the reference. */ #define devlinks_xa_for_each_registered_get(net, index, devlink) \ - for (index = 0, \ - devlink = devlinks_xa_find_get_first(net, &index); \ - devlink; devlink = devlinks_xa_find_get_next(net, &index)) + for (index = 0; (devlink = devlinks_xa_find_get(net, &index)); index++) -struct devlink * -devlinks_xa_find_get(struct net *net, unsigned long *indexp, - void * (*xa_find_fn)(struct xarray *, unsigned long *, - unsigned long, xa_mark_t)); -struct devlink * -devlinks_xa_find_get_first(struct net *net, unsigned long *indexp); -struct devlink * -devlinks_xa_find_get_next(struct net *net, unsigned long *indexp); +struct devlink *devlinks_xa_find_get(struct net *net, unsigned long *indexp); + +static inline bool devl_is_registered(struct devlink *devlink) +{ + /* To prevent races the caller must hold the instance lock + * or another lock taken during unregistration. + */ + return xa_get_mark(&devlinks, devlink->index, DEVLINK_REGISTERED); +} /* Netlink */ #define DEVLINK_NL_FLAG_NEED_PORT BIT(0) @@ -135,12 +132,13 @@ struct devlink_gen_cmd { */ #define devlink_dump_for_each_instance_get(msg, state, devlink) \ for (; (devlink = devlinks_xa_find_get(sock_net(msg->sk), \ - &state->instance, xa_find)); \ + &state->instance)); \ state->instance++, state->idx = 0) extern const struct genl_small_ops devlink_nl_ops[56]; -struct devlink *devlink_get_from_attrs(struct net *net, struct nlattr **attrs); +struct devlink * +devlink_get_from_attrs_lock(struct net *net, struct nlattr **attrs); void devlink_notify_unregister(struct devlink *devlink); void devlink_notify_register(struct devlink *devlink); diff --git a/net/devlink/leftover.c b/net/devlink/leftover.c index e6d6c7f74ae7..1e23b2da78cc 100644 --- a/net/devlink/leftover.c +++ b/net/devlink/leftover.c @@ -2130,6 +2130,9 @@ static int devlink_nl_cmd_linecard_get_dumpit(struct sk_buff *msg, int idx = 0; mutex_lock(&devlink->linecards_lock); + if (!devl_is_registered(devlink)) + goto next_devlink; + list_for_each_entry(linecard, &devlink->linecard_list, list) { if (idx < state->idx) { idx++; @@ -2151,6 +2154,7 @@ static int devlink_nl_cmd_linecard_get_dumpit(struct sk_buff *msg, } idx++; } +next_devlink: mutex_unlock(&devlink->linecards_lock); devlink_put(devlink); } @@ -5259,7 +5263,13 @@ static void devlink_param_notify(struct devlink *devlink, WARN_ON(cmd != DEVLINK_CMD_PARAM_NEW && cmd != DEVLINK_CMD_PARAM_DEL && cmd != DEVLINK_CMD_PORT_PARAM_NEW && cmd != DEVLINK_CMD_PORT_PARAM_DEL); - ASSERT_DEVLINK_REGISTERED(devlink); + + /* devlink_notify_register() / devlink_notify_unregister() + * will replay the notifications if the params are added/removed + * outside of the lifetime of the instance. + */ + if (!devl_is_registered(devlink)) + return; msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); if (!msg) @@ -6314,12 +6324,10 @@ static int devlink_nl_cmd_region_read_dumpit(struct sk_buff *skb, start_offset = state->start_offset; - devlink = devlink_get_from_attrs(sock_net(cb->skb->sk), attrs); + devlink = devlink_get_from_attrs_lock(sock_net(cb->skb->sk), attrs); if (IS_ERR(devlink)) return PTR_ERR(devlink); - devl_lock(devlink); - if (!attrs[DEVLINK_ATTR_REGION_NAME]) { NL_SET_ERR_MSG(cb->extack, "No region name provided"); err = -EINVAL; @@ -7735,9 +7743,10 @@ devlink_health_reporter_get_from_cb(struct netlink_callback *cb) struct nlattr **attrs = info->attrs; struct devlink *devlink; - devlink = devlink_get_from_attrs(sock_net(cb->skb->sk), attrs); + devlink = devlink_get_from_attrs_lock(sock_net(cb->skb->sk), attrs); if (IS_ERR(devlink)) return NULL; + devl_unlock(devlink); reporter = devlink_health_reporter_get_from_attrs(devlink, attrs); devlink_put(devlink); @@ -7810,6 +7819,12 @@ devlink_nl_cmd_health_reporter_get_dumpit(struct sk_buff *msg, int idx = 0; mutex_lock(&devlink->reporters_lock); + if (!devl_is_registered(devlink)) { + mutex_unlock(&devlink->reporters_lock); + devlink_put(devlink); + continue; + } + list_for_each_entry(reporter, &devlink->reporter_list, list) { if (idx < state->idx) { @@ -7831,6 +7846,9 @@ devlink_nl_cmd_health_reporter_get_dumpit(struct sk_buff *msg, mutex_unlock(&devlink->reporters_lock); devl_lock(devlink); + if (!devl_is_registered(devlink)) + goto next_devlink; + xa_for_each(&devlink->ports, port_index, port) { mutex_lock(&port->reporters_lock); list_for_each_entry(reporter, &port->reporter_list, list) { @@ -7854,6 +7872,7 @@ devlink_nl_cmd_health_reporter_get_dumpit(struct sk_buff *msg, } mutex_unlock(&port->reporters_lock); } +next_devlink: devl_unlock(devlink); devlink_put(devlink); } @@ -10902,8 +10921,6 @@ int devlink_params_register(struct devlink *devlink, const struct devlink_param *param = params; int i, err; - ASSERT_DEVLINK_NOT_REGISTERED(devlink); - for (i = 0; i < params_count; i++, param++) { err = devlink_param_register(devlink, param); if (err) @@ -10934,8 +10951,6 @@ void devlink_params_unregister(struct devlink *devlink, const struct devlink_param *param = params; int i; - ASSERT_DEVLINK_NOT_REGISTERED(devlink); - for (i = 0; i < params_count; i++, param++) devlink_param_unregister(devlink, param); } @@ -10955,8 +10970,6 @@ int devlink_param_register(struct devlink *devlink, { struct devlink_param_item *param_item; - ASSERT_DEVLINK_NOT_REGISTERED(devlink); - WARN_ON(devlink_param_verify(param)); WARN_ON(devlink_param_find_by_name(&devlink->param_list, param->name)); @@ -10972,6 +10985,7 @@ int devlink_param_register(struct devlink *devlink, param_item->param = param; list_add_tail(¶m_item->list, &devlink->param_list); + devlink_param_notify(devlink, 0, param_item, DEVLINK_CMD_PARAM_NEW); return 0; } EXPORT_SYMBOL_GPL(devlink_param_register); @@ -10986,11 +11000,10 @@ void devlink_param_unregister(struct devlink *devlink, { struct devlink_param_item *param_item; - ASSERT_DEVLINK_NOT_REGISTERED(devlink); - param_item = devlink_param_find_by_name(&devlink->param_list, param->name); WARN_ON(!param_item); + devlink_param_notify(devlink, 0, param_item, DEVLINK_CMD_PARAM_DEL); list_del(¶m_item->list); kfree(param_item); } @@ -11050,8 +11063,6 @@ int devlink_param_driverinit_value_set(struct devlink *devlink, u32 param_id, { struct devlink_param_item *param_item; - ASSERT_DEVLINK_NOT_REGISTERED(devlink); - param_item = devlink_param_find_by_id(&devlink->param_list, param_id); if (!param_item) return -EINVAL; @@ -11065,6 +11076,8 @@ int devlink_param_driverinit_value_set(struct devlink *devlink, u32 param_id, else param_item->driverinit_value = init_val; param_item->driverinit_value_valid = true; + + devlink_param_notify(devlink, 0, param_item, DEVLINK_CMD_PARAM_NEW); return 0; } EXPORT_SYMBOL_GPL(devlink_param_driverinit_value_set); @@ -12219,7 +12232,8 @@ void devlink_compat_running_version(struct devlink *devlink, return; devl_lock(devlink); - __devlink_compat_running_version(devlink, buf, len); + if (devl_is_registered(devlink)) + __devlink_compat_running_version(devlink, buf, len); devl_unlock(devlink); } @@ -12228,20 +12242,28 @@ int devlink_compat_flash_update(struct devlink *devlink, const char *file_name) struct devlink_flash_update_params params = {}; int ret; - if (!devlink->ops->flash_update) - return -EOPNOTSUPP; + devl_lock(devlink); + if (!devl_is_registered(devlink)) { + ret = -ENODEV; + goto out_unlock; + } + + if (!devlink->ops->flash_update) { + ret = -EOPNOTSUPP; + goto out_unlock; + } ret = request_firmware(¶ms.fw, file_name, devlink->dev); if (ret) - return ret; + goto out_unlock; - devl_lock(devlink); devlink_flash_update_begin_notify(devlink); ret = devlink->ops->flash_update(devlink, ¶ms, NULL); devlink_flash_update_end_notify(devlink); - devl_unlock(devlink); release_firmware(params.fw); +out_unlock: + devl_unlock(devlink); return ret; } diff --git a/net/devlink/netlink.c b/net/devlink/netlink.c index a552e723f4a6..b5b8ac6db2d1 100644 --- a/net/devlink/netlink.c +++ b/net/devlink/netlink.c @@ -82,7 +82,8 @@ static const struct nla_policy devlink_nl_policy[DEVLINK_ATTR_MAX + 1] = { [DEVLINK_ATTR_REGION_DIRECT] = { .type = NLA_FLAG }, }; -struct devlink *devlink_get_from_attrs(struct net *net, struct nlattr **attrs) +struct devlink * +devlink_get_from_attrs_lock(struct net *net, struct nlattr **attrs) { struct devlink *devlink; unsigned long index; @@ -96,9 +97,12 @@ struct devlink *devlink_get_from_attrs(struct net *net, struct nlattr **attrs) devname = nla_data(attrs[DEVLINK_ATTR_DEV_NAME]); devlinks_xa_for_each_registered_get(net, index, devlink) { - if (strcmp(devlink->dev->bus->name, busname) == 0 && + devl_lock(devlink); + if (devl_is_registered(devlink) && + strcmp(devlink->dev->bus->name, busname) == 0 && strcmp(dev_name(devlink->dev), devname) == 0) return devlink; + devl_unlock(devlink); devlink_put(devlink); } @@ -113,10 +117,10 @@ static int devlink_nl_pre_doit(const struct genl_split_ops *ops, struct devlink *devlink; int err; - devlink = devlink_get_from_attrs(genl_info_net(info), info->attrs); + devlink = devlink_get_from_attrs_lock(genl_info_net(info), info->attrs); if (IS_ERR(devlink)) return PTR_ERR(devlink); - devl_lock(devlink); + info->user_ptr[0] = devlink; if (ops->internal_flags & DEVLINK_NL_FLAG_NEED_PORT) { devlink_port = devlink_port_get_from_info(devlink, info); @@ -208,7 +212,12 @@ int devlink_nl_instance_iter_dump(struct sk_buff *msg, devlink_dump_for_each_instance_get(msg, state, devlink) { devl_lock(devlink); - err = cmd->dump_one(msg, devlink, cb); + + if (devl_is_registered(devlink)) + err = cmd->dump_one(msg, devlink, cb); + else + err = 0; + devl_unlock(devlink); devlink_put(devlink);