2018-05-31 19:14:40 +03:00
// SPDX-License-Identifier: GPL-2.0
# define _GNU_SOURCE
# include <arpa/inet.h>
# include <errno.h>
# include <error.h>
# include <fcntl.h>
# include <limits.h>
# include <linux/filter.h>
# include <linux/bpf.h>
# include <linux/if_packet.h>
# include <linux/if_vlan.h>
# include <linux/virtio_net.h>
# include <net/if.h>
# include <net/ethernet.h>
# include <netinet/ip.h>
# include <netinet/udp.h>
# include <poll.h>
# include <sched.h>
# include <stdbool.h>
# include <stdint.h>
# include <stdio.h>
# include <stdlib.h>
# include <string.h>
# include <sys/mman.h>
# include <sys/socket.h>
# include <sys/stat.h>
# include <sys/types.h>
# include <unistd.h>
# include "psock_lib.h"
static bool cfg_use_bind ;
static bool cfg_use_csum_off ;
static bool cfg_use_csum_off_bad ;
static bool cfg_use_dgram ;
static bool cfg_use_gso ;
static bool cfg_use_qdisc_bypass ;
static bool cfg_use_vlan ;
static bool cfg_use_vnet ;
static char * cfg_ifname = " lo " ;
static int cfg_mtu = 1500 ;
static int cfg_payload_len = DATA_LEN ;
static int cfg_truncate_len = INT_MAX ;
static uint16_t cfg_port = 8000 ;
/* test sending up to max mtu + 1 */
# define TEST_SZ (sizeof(struct virtio_net_hdr) + ETH_HLEN + ETH_MAX_MTU + 1)
static char tbuf [ TEST_SZ ] , rbuf [ TEST_SZ ] ;
static unsigned long add_csum_hword ( const uint16_t * start , int num_u16 )
{
unsigned long sum = 0 ;
int i ;
for ( i = 0 ; i < num_u16 ; i + + )
sum + = start [ i ] ;
return sum ;
}
static uint16_t build_ip_csum ( const uint16_t * start , int num_u16 ,
unsigned long sum )
{
sum + = add_csum_hword ( start , num_u16 ) ;
while ( sum > > 16 )
sum = ( sum & 0xffff ) + ( sum > > 16 ) ;
return ~ sum ;
}
static int build_vnet_header ( void * header )
{
struct virtio_net_hdr * vh = header ;
vh - > hdr_len = ETH_HLEN + sizeof ( struct iphdr ) + sizeof ( struct udphdr ) ;
if ( cfg_use_csum_off ) {
vh - > flags | = VIRTIO_NET_HDR_F_NEEDS_CSUM ;
vh - > csum_start = ETH_HLEN + sizeof ( struct iphdr ) ;
vh - > csum_offset = __builtin_offsetof ( struct udphdr , check ) ;
/* position check field exactly one byte beyond end of packet */
if ( cfg_use_csum_off_bad )
vh - > csum_start + = sizeof ( struct udphdr ) + cfg_payload_len -
vh - > csum_offset - 1 ;
}
if ( cfg_use_gso ) {
vh - > gso_type = VIRTIO_NET_HDR_GSO_UDP ;
vh - > gso_size = cfg_mtu - sizeof ( struct iphdr ) ;
}
return sizeof ( * vh ) ;
}
static int build_eth_header ( void * header )
{
struct ethhdr * eth = header ;
if ( cfg_use_vlan ) {
uint16_t * tag = header + ETH_HLEN ;
eth - > h_proto = htons ( ETH_P_8021Q ) ;
tag [ 1 ] = htons ( ETH_P_IP ) ;
return ETH_HLEN + 4 ;
}
eth - > h_proto = htons ( ETH_P_IP ) ;
return ETH_HLEN ;
}
static int build_ipv4_header ( void * header , int payload_len )
{
struct iphdr * iph = header ;
iph - > ihl = 5 ;
iph - > version = 4 ;
iph - > ttl = 8 ;
iph - > tot_len = htons ( sizeof ( * iph ) + sizeof ( struct udphdr ) + payload_len ) ;
iph - > id = htons ( 1337 ) ;
iph - > protocol = IPPROTO_UDP ;
iph - > saddr = htonl ( ( 172 < < 24 ) | ( 17 < < 16 ) | 2 ) ;
iph - > daddr = htonl ( ( 172 < < 24 ) | ( 17 < < 16 ) | 1 ) ;
iph - > check = build_ip_csum ( ( void * ) iph , iph - > ihl < < 1 , 0 ) ;
return iph - > ihl < < 2 ;
}
static int build_udp_header ( void * header , int payload_len )
{
const int alen = sizeof ( uint32_t ) ;
struct udphdr * udph = header ;
int len = sizeof ( * udph ) + payload_len ;
udph - > source = htons ( 9 ) ;
udph - > dest = htons ( cfg_port ) ;
udph - > len = htons ( len ) ;
if ( cfg_use_csum_off )
udph - > check = build_ip_csum ( header - ( 2 * alen ) , alen ,
htons ( IPPROTO_UDP ) + udph - > len ) ;
else
udph - > check = 0 ;
return sizeof ( * udph ) ;
}
static int build_packet ( int payload_len )
{
int off = 0 ;
off + = build_vnet_header ( tbuf ) ;
off + = build_eth_header ( tbuf + off ) ;
off + = build_ipv4_header ( tbuf + off , payload_len ) ;
off + = build_udp_header ( tbuf + off , payload_len ) ;
if ( off + payload_len > sizeof ( tbuf ) )
error ( 1 , 0 , " payload length exceeds max " ) ;
memset ( tbuf + off , DATA_CHAR , payload_len ) ;
return off + payload_len ;
}
static void do_bind ( int fd )
{
struct sockaddr_ll laddr = { 0 } ;
laddr . sll_family = AF_PACKET ;
laddr . sll_protocol = htons ( ETH_P_IP ) ;
laddr . sll_ifindex = if_nametoindex ( cfg_ifname ) ;
if ( ! laddr . sll_ifindex )
error ( 1 , errno , " if_nametoindex " ) ;
if ( bind ( fd , ( void * ) & laddr , sizeof ( laddr ) ) )
error ( 1 , errno , " bind " ) ;
}
static void do_send ( int fd , char * buf , int len )
{
int ret ;
if ( ! cfg_use_vnet ) {
buf + = sizeof ( struct virtio_net_hdr ) ;
len - = sizeof ( struct virtio_net_hdr ) ;
}
if ( cfg_use_dgram ) {
buf + = ETH_HLEN ;
len - = ETH_HLEN ;
}
if ( cfg_use_bind ) {
ret = write ( fd , buf , len ) ;
} else {
struct sockaddr_ll laddr = { 0 } ;
laddr . sll_protocol = htons ( ETH_P_IP ) ;
laddr . sll_ifindex = if_nametoindex ( cfg_ifname ) ;
if ( ! laddr . sll_ifindex )
error ( 1 , errno , " if_nametoindex " ) ;
ret = sendto ( fd , buf , len , 0 , ( void * ) & laddr , sizeof ( laddr ) ) ;
}
if ( ret = = - 1 )
error ( 1 , errno , " write " ) ;
if ( ret ! = len )
error ( 1 , 0 , " write: %u %u " , ret , len ) ;
fprintf ( stderr , " tx: %u \n " , ret ) ;
}
static int do_tx ( void )
{
const int one = 1 ;
int fd , len ;
fd = socket ( PF_PACKET , cfg_use_dgram ? SOCK_DGRAM : SOCK_RAW , 0 ) ;
if ( fd = = - 1 )
error ( 1 , errno , " socket t " ) ;
if ( cfg_use_bind )
do_bind ( fd ) ;
if ( cfg_use_qdisc_bypass & &
setsockopt ( fd , SOL_PACKET , PACKET_QDISC_BYPASS , & one , sizeof ( one ) ) )
error ( 1 , errno , " setsockopt qdisc bypass " ) ;
if ( cfg_use_vnet & &
setsockopt ( fd , SOL_PACKET , PACKET_VNET_HDR , & one , sizeof ( one ) ) )
error ( 1 , errno , " setsockopt vnet " ) ;
len = build_packet ( cfg_payload_len ) ;
if ( cfg_truncate_len < len )
len = cfg_truncate_len ;
do_send ( fd , tbuf , len ) ;
if ( close ( fd ) )
error ( 1 , errno , " close t " ) ;
return len ;
}
static int setup_rx ( void )
{
struct timeval tv = { . tv_usec = 100 * 1000 } ;
struct sockaddr_in raddr = { 0 } ;
int fd ;
fd = socket ( PF_INET , SOCK_DGRAM , 0 ) ;
if ( fd = = - 1 )
error ( 1 , errno , " socket r " ) ;
if ( setsockopt ( fd , SOL_SOCKET , SO_RCVTIMEO , & tv , sizeof ( tv ) ) )
error ( 1 , errno , " setsockopt rcv timeout " ) ;
raddr . sin_family = AF_INET ;
raddr . sin_port = htons ( cfg_port ) ;
raddr . sin_addr . s_addr = htonl ( INADDR_ANY ) ;
if ( bind ( fd , ( void * ) & raddr , sizeof ( raddr ) ) )
error ( 1 , errno , " bind r " ) ;
return fd ;
}
static void do_rx ( int fd , int expected_len , char * expected )
{
int ret ;
ret = recv ( fd , rbuf , sizeof ( rbuf ) , 0 ) ;
if ( ret = = - 1 )
error ( 1 , errno , " recv " ) ;
if ( ret ! = expected_len )
error ( 1 , 0 , " recv: %u != %u " , ret , expected_len ) ;
if ( memcmp ( rbuf , expected , ret ) )
error ( 1 , 0 , " recv: data mismatch " ) ;
fprintf ( stderr , " rx: %u \n " , ret ) ;
}
static int setup_sniffer ( void )
{
struct timeval tv = { . tv_usec = 100 * 1000 } ;
int fd ;
fd = socket ( PF_PACKET , SOCK_RAW , 0 ) ;
if ( fd = = - 1 )
error ( 1 , errno , " socket p " ) ;
if ( setsockopt ( fd , SOL_SOCKET , SO_RCVTIMEO , & tv , sizeof ( tv ) ) )
error ( 1 , errno , " setsockopt rcv timeout " ) ;
pair_udp_setfilter ( fd ) ;
do_bind ( fd ) ;
return fd ;
}
static void parse_opts ( int argc , char * * argv )
{
int c ;
while ( ( c = getopt ( argc , argv , " bcCdgl:qt:vV " ) ) ! = - 1 ) {
switch ( c ) {
case ' b ' :
cfg_use_bind = true ;
break ;
case ' c ' :
cfg_use_csum_off = true ;
break ;
case ' C ' :
cfg_use_csum_off_bad = true ;
break ;
case ' d ' :
cfg_use_dgram = true ;
break ;
case ' g ' :
cfg_use_gso = true ;
break ;
case ' l ' :
cfg_payload_len = strtoul ( optarg , NULL , 0 ) ;
break ;
case ' q ' :
cfg_use_qdisc_bypass = true ;
break ;
case ' t ' :
cfg_truncate_len = strtoul ( optarg , NULL , 0 ) ;
break ;
case ' v ' :
cfg_use_vnet = true ;
break ;
case ' V ' :
cfg_use_vlan = true ;
break ;
default :
error ( 1 , 0 , " %s: parse error " , argv [ 0 ] ) ;
}
}
if ( cfg_use_vlan & & cfg_use_dgram )
error ( 1 , 0 , " option vlan (-V) conflicts with dgram (-d) " ) ;
if ( cfg_use_csum_off & & ! cfg_use_vnet )
error ( 1 , 0 , " option csum offload (-c) requires vnet (-v) " ) ;
if ( cfg_use_csum_off_bad & & ! cfg_use_csum_off )
error ( 1 , 0 , " option csum bad (-C) requires csum offload (-c) " ) ;
if ( cfg_use_gso & & ! cfg_use_csum_off )
error ( 1 , 0 , " option gso (-g) requires csum offload (-c) " ) ;
}
static void run_test ( void )
{
int fdr , fds , total_len ;
fdr = setup_rx ( ) ;
fds = setup_sniffer ( ) ;
total_len = do_tx ( ) ;
/* BPF filter accepts only this length, vlan changes MAC */
if ( cfg_payload_len = = DATA_LEN & & ! cfg_use_vlan )
do_rx ( fds , total_len - sizeof ( struct virtio_net_hdr ) ,
tbuf + sizeof ( struct virtio_net_hdr ) ) ;
do_rx ( fdr , cfg_payload_len , tbuf + total_len - cfg_payload_len ) ;
if ( close ( fds ) )
error ( 1 , errno , " close s " ) ;
if ( close ( fdr ) )
error ( 1 , errno , " close r " ) ;
}
int main ( int argc , char * * argv )
{
parse_opts ( argc , argv ) ;
if ( system ( " ip link set dev lo mtu 1500 " ) )
error ( 1 , errno , " ip link set mtu " ) ;
if ( system ( " ip addr add dev lo 172.17.0.1/24 " ) )
error ( 1 , errno , " ip addr add " ) ;
2022-05-25 06:18:19 +03:00
if ( system ( " sysctl -w net.ipv4.conf.lo.accept_local=1 " ) )
error ( 1 , errno , " sysctl lo.accept_local " ) ;
2018-05-31 19:14:40 +03:00
run_test ( ) ;
fprintf ( stderr , " OK \n \n " ) ;
return 0 ;
}