diff --git a/drivers/infiniband/core/addr.c b/drivers/infiniband/core/addr.c
index 97d0b36b5120..316a53f59ee8 100644
--- a/drivers/infiniband/core/addr.c
+++ b/drivers/infiniband/core/addr.c
@@ -450,23 +450,28 @@ static int addr6_resolve(struct sockaddr *src_sock,
 static int addr_resolve_neigh(const struct dst_entry *dst,
 			      const struct sockaddr *dst_in,
 			      struct rdma_dev_addr *addr,
+			      unsigned int ndev_flags,
 			      u32 seq)
 {
-	if (dst->dev->flags & IFF_LOOPBACK) {
+	int ret = 0;
+
+	/*
+	 * Decide from the flags snapshot passed in by the caller instead of
+	 * dereferencing dst->dev again.
+	 */
+	if (ndev_flags & IFF_LOOPBACK) {
 		memcpy(addr->dst_dev_addr, addr->src_dev_addr, MAX_ADDR_LEN);
-		return 0;
+	} else if (!(ndev_flags & IFF_NOARP)) {
+		/* If the device doesn't do ARP internally */
+		ret = fetch_ha(dst, addr, dst_in, seq);
 	}
-
-	/* If the device doesn't do ARP internally */
-	if (!(dst->dev->flags & IFF_NOARP))
-		return fetch_ha(dst, addr, dst_in, seq);
-
-	return 0;
+	return ret;
 }
 
-static int rdma_set_src_addr(const struct dst_entry *dst,
+static void copy_src_l2_addr(struct rdma_dev_addr *dev_addr,
 			     const struct sockaddr *dst_in,
-			     struct rdma_dev_addr *dev_addr)
+			     const struct dst_entry *dst,
+			     const struct net_device *ndev)
 {
 	int ret = 0;
 
@@ -481,14 +486,53 @@ static int rdma_set_src_addr(const struct dst_entry *dst,
 	 * network type accordingly.
 	 */
 	if (has_gateway(dst, dst_in->sa_family) &&
-	    dst->dev->type != ARPHRD_INFINIBAND)
+	    ndev->type != ARPHRD_INFINIBAND)
 		dev_addr->network = dst_in->sa_family == AF_INET ?
 						RDMA_NETWORK_IPV4 :
 						RDMA_NETWORK_IPV6;
 	else
 		dev_addr->network = RDMA_NETWORK_IB;
+}
 
-	return ret;
+/**
+ * rdma_set_src_addr_rcu - set the source L2 address for a resolved route
+ * @dev_addr:	address structure handed to copy_src_l2_addr() to be filled in
+ * @ndev_flags:	receives the flags of dst->dev as read on entry
+ * @dst_in:	destination address the route was resolved for
+ * @dst:	resolved route entry
+ *
+ * Called by addr_resolve() with rcu_read_lock() held.
+ *
+ * Return: 0 on success, -ENODEV when a route over a loopback device cannot
+ * be translated to another net_device.
+ */
+static int rdma_set_src_addr_rcu(struct rdma_dev_addr *dev_addr,
+				 unsigned int *ndev_flags,
+				 const struct sockaddr *dst_in,
+				 const struct dst_entry *dst)
+{
+	struct net_device *ndev = READ_ONCE(dst->dev);
+
+	*ndev_flags = ndev->flags;
+	/* A physical device must be the RDMA device to use */
+	if (*ndev_flags & IFF_LOOPBACK) {
+		/*
+		 * RDMA (IB/RoCE, iWarp) doesn't run on lo interface or
+		 * loopback IP address. So if route is resolved to loopback
+		 * interface, translate that to a real ndev based on non
+		 * loopback IP address.
+		 */
+		ndev = rdma_find_ndev_for_src_ip_rcu(dev_net(ndev), dst_in);
+		/*
+		 * A failed lookup may come back as NULL or as an ERR_PTR();
+		 * neither may be passed on as a net_device.
+		 */
+		if (IS_ERR_OR_NULL(ndev))
+			return -ENODEV;
+	}
+
+	copy_src_l2_addr(dev_addr, dst_in, dst, ndev);
+	return 0;
 }
 
 static int addr_resolve(struct sockaddr *src_in,
@@ -498,6 +542,7 @@ static int addr_resolve(struct sockaddr *src_in,
 			u32 seq)
 {
 	struct dst_entry *dst = NULL;
+	unsigned int ndev_flags = 0;
 	struct rtable *rt = NULL;
 	int ret;
 
@@ -506,22 +551,26 @@ static int addr_resolve(struct sockaddr *src_in,
 		return -EINVAL;
 	}
 
+	rcu_read_lock();
 	if (src_in->sa_family == AF_INET) {
 		ret = addr4_resolve(src_in, dst_in, addr, &rt);
 		dst = &rt->dst;
 	} else {
 		ret = addr6_resolve(src_in, dst_in, addr, &dst);
 	}
-	if (ret)
+	if (ret) {
+		rcu_read_unlock();
 		return ret;
+	}
+	ret = rdma_set_src_addr_rcu(addr, &ndev_flags, dst_in, dst);
+	rcu_read_unlock();
 
-	ret = rdma_set_src_addr(dst, dst_in, addr);
 	/*
 	 * Resolve neighbor destination address if requested and
 	 * only if src addr translation didn't fail.
 	 */
 	if (!ret && resolve_neigh)
-		ret = addr_resolve_neigh(dst, dst_in, addr, seq);
+		ret = addr_resolve_neigh(dst, dst_in, addr, ndev_flags, seq);
 
 	if (src_in->sa_family == AF_INET)
 		ip_rt_put(rt);