diff --git a/server/src/tests/core/util/test_net.py b/server/src/tests/core/util/test_net.py index 3b2a8b366..66258063e 100644 --- a/server/src/tests/core/util/test_net.py +++ b/server/src/tests/core/util/test_net.py @@ -29,6 +29,7 @@ """ Author: Adolfo Gómez, dkmaster at dkmon dot com """ +import ipaddress import logging @@ -40,7 +41,7 @@ logger = logging.getLogger(__name__) class NetTest(UDSTestCase): - def testNetworkFromStringIPv4(self) -> None: + def test_network_from_string_ipv4(self) -> None: for n in ( ('*', 0, 4294967295), ('192.168.0.1', 3232235521, 3232235521), @@ -100,22 +101,10 @@ class NetTest(UDSTestCase): self.assertTrue(net.contains('192.168.0.0/24', '192.168.0.{}'.format(n2))) for n3 in range(4294): - self.assertTrue( - net.contains([net.NetworkType(0, 4294967295, 4)], n3 * 1000) - ) - self.assertTrue( - net.contains(net.NetworkType(0, 4294967295, 4), n3 * 1000) - ) + self.assertTrue(net.contains([net.NetworkType(0, 4294967295, 4)], n3 * 1000)) + self.assertTrue(net.contains(net.NetworkType(0, 4294967295, 4), n3 * 1000)) - # Test some ip conversions from long to ip and viceversa - for n4 in ('172', '192', '10'): - for n5 in range(0, 256, 17): - for n6 in range(0, 256, 13): - for n7 in range(0, 256, 11): - ip = '{0}.{1}.{2}.{3}'.format(n4, n5, n6, n7) - self.assertEqual(net.long_to_ip(net.ip_to_long(ip).ip, 4), ip) - - def testNetworkFromStringIPv6(self) -> None: + def test_network_from_string_ipv6(self) -> None: # IPv6 only support standard notation, and '*', but not "netmask" or "range" for n in ( ( @@ -149,9 +138,7 @@ class NetTest(UDSTestCase): 338620831926207318622244848606417780735, ), ): - multiple_net: list[net.NetworkType] = net.networks_from_str( - n[0], version=(6 if n[0] == '*' else 0) - ) + multiple_net: list[net.NetworkType] = net.networks_from_str(n[0], version=(6 if n[0] == '*' else 0)) self.assertEqual( len(multiple_net), 1, @@ -168,9 +155,7 @@ class NetTest(UDSTestCase): 'Incorrect network end value for {0}'.format(n[0]), ) - single_net: net.NetworkType = net.network_from_str( - n[0], version=(6 if n[0] == '*' else 0) - ) + single_net: net.NetworkType = net.network_from_str(n[0], version=(6 if n[0] == '*' else 0)) self.assertEqual( len(single_net), 3, @@ -187,8 +172,31 @@ class NetTest(UDSTestCase): 'Incorrect network end value for {0}'.format(n[0]), ) + def test_ip_to_long(self) -> None: + for i in range(32): + ipv4 = f'192.168.{i}.{i*2}' + ipv4_num = int(ipaddress.IPv4Address(ipv4)) + ipv6 = f'2001:{0xdb8+i:04x}::{i:02x}' + ipv6_num = int(ipaddress.IPv6Address(ipv6)) + + self.assertEqual(net.ip_to_long(ipv4).ip, ipv4_num) + self.assertEqual(net.ip_to_long(ipv4).version, 4) + self.assertEqual(net.ip_to_long(f'::ffff:{ipv4}').ip, ipv4_num) + self.assertEqual(net.ip_to_long(f'::ffff:{ipv4}').version, 4) + + self.assertEqual(net.ip_to_long(ipv6).ip, ipv6_num) + self.assertEqual(net.ip_to_long(ipv6).version, 6) + + # Test some ip conversions from long to ip and viceversa + for a in ('172', '192', '10'): + for b in range(0, 256, 17): + for ipv6 in range(0, 256, 13): + for d in range(0, 256, 11): + ipv4 = '{0}.{1}.{2}.{3}'.format(a, b, ipv6, d) + self.assertEqual(net.long_to_ip(net.ip_to_long(ipv4).ip, 4), ipv4) + # iterate some ipv6 addresses - for n6 in ( + for ipv6 in ( '2001:db8::1', '2001:1::1', '2001:2:3::1', @@ -197,4 +205,4 @@ class NetTest(UDSTestCase): '2001:2:3:4:5:6:0:1', ): # Ensure converting back to string ips works - self.assertEqual(net.long_to_ip(net.ip_to_long(n6).ip, 6), n6) + self.assertEqual(net.long_to_ip(net.ip_to_long(ipv6).ip, 6), ipv6) diff --git a/server/src/uds/core/util/net.py b/server/src/uds/core/util/net.py index 136c9bb4f..123ed6776 100644 --- a/server/src/uds/core/util/net.py +++ b/server/src/uds/core/util/net.py @@ -74,12 +74,11 @@ def ip_to_long(ip: str) -> IpType: """ # First, check if it's an ipv6 address try: - if ':' in ip and '.' not in ip: - return IpType(int(ipaddress.IPv6Address(ip)), 6) - if ':' in ip and '.' in ip: - ip = ip.split(':')[ - -1 - ] # Last part of ipv6 address is ipv4 address (has dots and colons, so we can't use ipaddress) + if ':' in ip: + if '.' in ip: # Is , for example, '::ffff:172.27.0.1' + ip = ip.split(':')[-1] + else: + return IpType(int(ipaddress.IPv6Address(ip)), 6) return IpType(int(ipaddress.IPv4Address(ip)), 4) except Exception as e: logger.error('Ivalid value: %s (%s)', ip, e)