diff --git a/include/net/netfilter/nf_queue.h b/include/net/netfilter/nf_queue.h index 47088083667b..c204af20c27e 100644 --- a/include/net/netfilter/nf_queue.h +++ b/include/net/netfilter/nf_queue.h @@ -34,7 +34,7 @@ void nf_register_queue_handler(struct net *net, const struct nf_queue_handler *q void nf_unregister_queue_handler(struct net *net); void nf_reinject(struct nf_queue_entry *entry, unsigned int verdict); -void nf_queue_entry_get_refs(struct nf_queue_entry *entry); +bool nf_queue_entry_get_refs(struct nf_queue_entry *entry); void nf_queue_entry_release_refs(struct nf_queue_entry *entry); static inline void init_hashrandom(u32 *jhash_initval) diff --git a/net/netfilter/nf_queue.c b/net/netfilter/nf_queue.c index ee3b57c25a6a..643dbfe7c581 100644 --- a/net/netfilter/nf_queue.c +++ b/net/netfilter/nf_queue.c @@ -108,18 +108,20 @@ static void nf_queue_entry_get_br_nf_refs(struct sk_buff *skb) } /* Bump dev refs so they don't vanish while packet is out */ -void nf_queue_entry_get_refs(struct nf_queue_entry *entry) +bool nf_queue_entry_get_refs(struct nf_queue_entry *entry) { struct nf_hook_state *state = &entry->state; + if (state->sk && !refcount_inc_not_zero(&state->sk->sk_refcnt)) + return false; + if (state->in) dev_hold(state->in); if (state->out) dev_hold(state->out); - if (state->sk) - sock_hold(state->sk); nf_queue_entry_get_br_nf_refs(entry->skb); + return true; } EXPORT_SYMBOL_GPL(nf_queue_entry_get_refs); @@ -210,7 +212,10 @@ static int __nf_queue(struct sk_buff *skb, const struct nf_hook_state *state, .size = sizeof(*entry) + route_key_size, }; - nf_queue_entry_get_refs(entry); + if (!nf_queue_entry_get_refs(entry)) { + kfree(entry); + return -ENOTCONN; + } switch (entry->state.pf) { case AF_INET: diff --git a/net/netfilter/nfnetlink_queue.c b/net/netfilter/nfnetlink_queue.c index ca21f8f4a47c..7d3ab08a5a2d 100644 --- a/net/netfilter/nfnetlink_queue.c +++ b/net/netfilter/nfnetlink_queue.c @@ -712,9 +712,15 @@ static struct nf_queue_entry * nf_queue_entry_dup(struct nf_queue_entry *e) { struct nf_queue_entry *entry = kmemdup(e, e->size, GFP_ATOMIC); - if (entry) - nf_queue_entry_get_refs(entry); - return entry; + + if (!entry) + return NULL; + + if (nf_queue_entry_get_refs(entry)) + return entry; + + kfree(entry); + return NULL; } #if IS_ENABLED(CONFIG_BRIDGE_NETFILTER)