diff --git a/include/net/mctpdevice.h b/include/net/mctpdevice.h index 71a11012fac7..3a439463f055 100644 --- a/include/net/mctpdevice.h +++ b/include/net/mctpdevice.h @@ -17,6 +17,8 @@ struct mctp_dev { struct net_device *dev; + refcount_t refs; + unsigned int net; /* Only modified under RTNL. Reads have addrs_lock held */ @@ -32,4 +34,7 @@ struct mctp_dev { struct mctp_dev *mctp_dev_get_rtnl(const struct net_device *dev); struct mctp_dev *__mctp_dev_get(const struct net_device *dev); +void mctp_dev_hold(struct mctp_dev *mdev); +void mctp_dev_put(struct mctp_dev *mdev); + #endif /* __NET_MCTPDEVICE_H */ diff --git a/net/mctp/device.c b/net/mctp/device.c index c34963974cc1..f556c6d01abc 100644 --- a/net/mctp/device.c +++ b/net/mctp/device.c @@ -35,14 +35,6 @@ struct mctp_dev *mctp_dev_get_rtnl(const struct net_device *dev) return rtnl_dereference(dev->mctp_ptr); } -static void mctp_dev_destroy(struct mctp_dev *mdev) -{ - struct net_device *dev = mdev->dev; - - dev_put(dev); - kfree_rcu(mdev, rcu); -} - static int mctp_fill_addrinfo(struct sk_buff *skb, struct netlink_callback *cb, struct mctp_dev *mdev, mctp_eid_t eid) { @@ -255,6 +247,19 @@ static int mctp_rtm_deladdr(struct sk_buff *skb, struct nlmsghdr *nlh, return 0; } +void mctp_dev_hold(struct mctp_dev *mdev) +{ + refcount_inc(&mdev->refs); +} + +void mctp_dev_put(struct mctp_dev *mdev) +{ + if (refcount_dec_and_test(&mdev->refs)) { + dev_put(mdev->dev); + kfree_rcu(mdev, rcu); + } +} + static struct mctp_dev *mctp_add_dev(struct net_device *dev) { struct mctp_dev *mdev; @@ -270,7 +275,9 @@ static struct mctp_dev *mctp_add_dev(struct net_device *dev) mdev->net = mctp_default_net(dev_net(dev)); /* associate to net_device */ + refcount_set(&mdev->refs, 1); rcu_assign_pointer(dev->mctp_ptr, mdev); + dev_hold(dev); mdev->dev = dev; @@ -345,7 +352,7 @@ static void mctp_unregister(struct net_device *dev) mctp_neigh_remove_dev(mdev); kfree(mdev->addrs); - mctp_dev_destroy(mdev); + mctp_dev_put(mdev); } static int mctp_register(struct net_device *dev) diff --git a/net/mctp/neigh.c b/net/mctp/neigh.c index 90ed2f02d1fb..5cc042121493 100644 --- a/net/mctp/neigh.c +++ b/net/mctp/neigh.c @@ -47,7 +47,7 @@ static int mctp_neigh_add(struct mctp_dev *mdev, mctp_eid_t eid, } INIT_LIST_HEAD(&neigh->list); neigh->dev = mdev; - dev_hold(neigh->dev->dev); + mctp_dev_hold(neigh->dev); neigh->eid = eid; neigh->source = source; memcpy(neigh->ha, lladdr, lladdr_len); @@ -63,7 +63,7 @@ static void __mctp_neigh_free(struct rcu_head *rcu) { struct mctp_neigh *neigh = container_of(rcu, struct mctp_neigh, rcu); - dev_put(neigh->dev->dev); + mctp_dev_put(neigh->dev); kfree(neigh); } diff --git a/net/mctp/route.c b/net/mctp/route.c index b2243b150e71..37aa67847a5a 100644 --- a/net/mctp/route.c +++ b/net/mctp/route.c @@ -455,7 +455,7 @@ static int mctp_route_output(struct mctp_route *route, struct sk_buff *skb) static void mctp_route_release(struct mctp_route *rt) { if (refcount_dec_and_test(&rt->refs)) { - dev_put(rt->dev->dev); + mctp_dev_put(rt->dev); kfree_rcu(rt, rcu); } } @@ -815,7 +815,7 @@ static int mctp_route_add(struct mctp_dev *mdev, mctp_eid_t daddr_start, rt->max = daddr_start + daddr_extent; rt->mtu = mtu; rt->dev = mdev; - dev_hold(rt->dev->dev); + mctp_dev_hold(rt->dev); rt->type = type; rt->output = rtfn;