diff --git a/include/net/ip6_route.h b/include/net/ip6_route.h index 877f682989b8..295d291269e2 100644 --- a/include/net/ip6_route.h +++ b/include/net/ip6_route.h @@ -64,8 +64,16 @@ static inline bool rt6_need_strict(const struct in6_addr *daddr) void ip6_route_input(struct sk_buff *skb); -struct dst_entry *ip6_route_output(struct net *net, const struct sock *sk, - struct flowi6 *fl6); +struct dst_entry *ip6_route_output_flags(struct net *net, const struct sock *sk, + struct flowi6 *fl6, int flags); + +static inline struct dst_entry *ip6_route_output(struct net *net, + const struct sock *sk, + struct flowi6 *fl6) +{ + return ip6_route_output_flags(net, sk, fl6, 0); +} + struct dst_entry *ip6_route_lookup(struct net *net, struct flowi6 *fl6, int flags); diff --git a/net/ipv6/ip6_output.c b/net/ipv6/ip6_output.c index 23de98f976d5..a163102f1803 100644 --- a/net/ipv6/ip6_output.c +++ b/net/ipv6/ip6_output.c @@ -909,6 +909,7 @@ static int ip6_dst_lookup_tail(struct net *net, const struct sock *sk, struct rt6_info *rt; #endif int err; + int flags = 0; /* The correct way to handle this would be to do * ip6_route_get_saddr, and then ip6_route_output; however, @@ -940,10 +941,13 @@ static int ip6_dst_lookup_tail(struct net *net, const struct sock *sk, dst_release(*dst); *dst = NULL; } + + if (fl6->flowi6_oif) + flags |= RT6_LOOKUP_F_IFACE; } if (!*dst) - *dst = ip6_route_output(net, sk, fl6); + *dst = ip6_route_output_flags(net, sk, fl6, flags); err = (*dst)->error; if (err) diff --git a/net/ipv6/route.c b/net/ipv6/route.c index 3c8834bc822d..ed446639219c 100644 --- a/net/ipv6/route.c +++ b/net/ipv6/route.c @@ -1183,11 +1183,10 @@ static struct rt6_info *ip6_pol_route_output(struct net *net, struct fib6_table return ip6_pol_route(net, table, fl6->flowi6_oif, fl6, flags); } -struct dst_entry *ip6_route_output(struct net *net, const struct sock *sk, - struct flowi6 *fl6) +struct dst_entry *ip6_route_output_flags(struct net *net, const struct sock *sk, + struct flowi6 *fl6, int flags) { struct dst_entry *dst; - int flags = 0; bool any_src; dst = l3mdev_rt6_dst_by_oif(net, fl6); @@ -1208,7 +1207,7 @@ struct dst_entry *ip6_route_output(struct net *net, const struct sock *sk, return fib6_rule_lookup(net, fl6, flags, ip6_pol_route_output); } -EXPORT_SYMBOL(ip6_route_output); +EXPORT_SYMBOL_GPL(ip6_route_output_flags); struct dst_entry *ip6_blackhole_route(struct net *net, struct dst_entry *dst_orig) {