diff --git a/include/net/netfilter/nf_tables.h b/include/net/netfilter/nf_tables.h index 1e6e4af4df0a..3ff6b3362800 100644 --- a/include/net/netfilter/nf_tables.h +++ b/include/net/netfilter/nf_tables.h @@ -567,6 +567,11 @@ static inline void *nft_set_priv(const struct nft_set *set) return (void *)set->data; } +static inline enum nft_data_types nft_set_datatype(const struct nft_set *set) +{ + return set->dtype == NFT_DATA_VERDICT ? NFT_DATA_VERDICT : NFT_DATA_VALUE; +} + static inline bool nft_set_gc_is_pending(const struct nft_set *s) { return refcount_read(&s->refs) != 1; diff --git a/net/netfilter/nf_tables_api.c b/net/netfilter/nf_tables_api.c index 3999b89793fc..506dc5c4cdcc 100644 --- a/net/netfilter/nf_tables_api.c +++ b/net/netfilter/nf_tables_api.c @@ -5328,8 +5328,7 @@ static int nf_tables_fill_setelem(struct sk_buff *skb, if (nft_set_ext_exists(ext, NFT_SET_EXT_DATA) && nft_data_dump(skb, NFTA_SET_ELEM_DATA, nft_set_ext_data(ext), - set->dtype == NFT_DATA_VERDICT ? NFT_DATA_VERDICT : NFT_DATA_VALUE, - set->dlen) < 0) + nft_set_datatype(set), set->dlen) < 0) goto nla_put_failure; if (nft_set_ext_exists(ext, NFT_SET_EXT_EXPRESSIONS) && @@ -10249,6 +10248,9 @@ static int nft_validate_register_store(const struct nft_ctx *ctx, return 0; default: + if (type != NFT_DATA_VALUE) + return -EINVAL; + if (reg < NFT_REG_1 * NFT_REG_SIZE / NFT_REG32_SIZE) return -EINVAL; if (len == 0) @@ -10257,8 +10259,6 @@ static int nft_validate_register_store(const struct nft_ctx *ctx, sizeof_field(struct nft_regs, data)) return -ERANGE; - if (data != NULL && type != NFT_DATA_VALUE) - return -EINVAL; return 0; } } diff --git a/net/netfilter/nft_lookup.c b/net/netfilter/nft_lookup.c index 9d18c5428d53..b9df27c2718b 100644 --- a/net/netfilter/nft_lookup.c +++ b/net/netfilter/nft_lookup.c @@ -136,7 +136,8 @@ static int nft_lookup_init(const struct nft_ctx *ctx, return -EINVAL; err = nft_parse_register_store(ctx, tb[NFTA_LOOKUP_DREG], - &priv->dreg, NULL, set->dtype, + &priv->dreg, NULL, + nft_set_datatype(set), set->dlen); if (err < 0) return err;