1
0
mirror of https://github.com/dkmstr/openuds.git synced 2024-12-23 17:34:17 +03:00

Added tests for ipv6 networks

This commit is contained in:
Adolfo Gómez García 2022-11-30 18:30:45 +01:00
parent d1e4f4d222
commit 0a41db6b53
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
12 changed files with 42 additions and 37 deletions

View File

@ -36,6 +36,9 @@ from uds.models import Network
from ...utils.test import UDSTestCase
NET_IPV4_TEMPLATE = '192.168.{}.0/24'
NET_IPV6_TEMPLATE = '2001:db8:85a3:8d3:13{:02x}::/64'
logger = logging.getLogger(__name__)
class NetworkModelTest(UDSTestCase):
@ -44,9 +47,9 @@ class NetworkModelTest(UDSTestCase):
def setUp(self) -> None:
super().setUp()
self.nets = []
for i in range(32):
for i in range(0, 255, 15):
n = Network()
n.name = f'Network {i}'
n.name = f'{i}'
if i % 2 == 0:
n.net_string = f'192.168.{i}.0/24'
else: # ipv6 net
@ -55,4 +58,19 @@ class NetworkModelTest(UDSTestCase):
self.nets.append(n)
def testNetworks(self) -> None:
pass
for n in self.nets:
i = int(n.name)
if i % 2 == 0: # ipv4 net
self.assertEqual(n.net_string, NET_IPV4_TEMPLATE.format(i))
# Test some ips in range are in net
for r in range(0, 256, 15):
self.assertTrue(n.contains(f'192.168.{i}.{r}'), f'192.168.{i}.{r} is not in {n.net_string}')
self.assertTrue(f'192.168.{i}.{r}' in n, f'192.168.{i}.{r} is not in {n.net_string}')
else: # ipv6 net
self.assertEqual(n.net_string, NET_IPV6_TEMPLATE.format(i))
# Test some ips in range are in net
for r in range(0, 65536, 255):
self.assertTrue(n.contains(f'2001:db8:85a3:8d3:13{i:02x}:{r:04x}::'), f'2001:db8:85a3:8d3:13{i:02x}:{r:04x}:: is not in {n.net_string}')
self.assertTrue(f'2001:db8:85a3:8d3:13{i:02x}:{r:04x}::' in n, f'2001:db8:85a3:8d3:13{i:02x}:{r:04x}:: is not in {n.net_string}')
self.assertTrue(n.contains(f'2001:db8:85a3:8d3:13{i:02x}:{r:04x}:{r:04x}::'), f'2001:db8:85a3:8d3:13{i:02x}:{r:04x}:{r:04x}:: is not in {n.net_string}')
self.assertTrue(f'2001:db8:85a3:8d3:13{i:02x}:{r:04x}:{r:04x}::' in n, f'2001:db8:85a3:8d3:13{i:02x}:{r:04x}:{r:04x}:: is not in {n.net_string}')

View File

@ -98,14 +98,14 @@ class NetTest(UDSTestCase):
self.assertEqual(net.ipToLong('192.168.0.5').ip, 3232235525)
self.assertEqual(net.longToIp(3232235525, 4), '192.168.0.5')
for n in range(0, 255):
self.assertTrue(net.ipInNetwork('192.168.0.{}'.format(n), '192.168.0.0/24'))
self.assertTrue(net.contains('192.168.0.0/24', '192.168.0.{}'.format(n)))
for n in range(4294):
self.assertTrue(
net.ipInNetwork(n * 1000, [net.NetworkType(0, 4294967295, 4)])
net.contains([net.NetworkType(0, 4294967295, 4)], n * 1000)
)
self.assertTrue(
net.ipInNetwork(n * 1000, net.NetworkType(0, 4294967295, 4))
net.contains(net.NetworkType(0, 4294967295, 4), n * 1000)
)
# Test some ip conversions from long to ip and viceversa

View File

@ -344,8 +344,8 @@ class Handler:
def validSource(self) -> bool:
try:
return net.ipInNetwork(
self._request.ip, GlobalConfig.ADMIN_TRUSTED_SOURCES.get(True)
return net.contains(
GlobalConfig.ADMIN_TRUSTED_SOURCES.get(True), self._request.ip
)
except Exception as e:
logger.warning(

View File

@ -89,7 +89,7 @@ class IPAuth(auths.Authenticator):
# The ranges are stored in group names
for g in groupsManager.getGroupsNames():
try:
if net.ipInNetwork(username, g):
if net.contains(g, username):
groupsManager.validate(g)
except Exception as e:
logger.error('Invalid network for IP auth: %s', e)
@ -113,7 +113,7 @@ class IPAuth(auths.Authenticator):
"""
validNets = self.visibleFromNets.value.strip()
# If has networks and not in any of them, not visible
if validNets and not net.ipInNetwork(request.ip, validNets):
if validNets and not net.contains(request.ip, validNets):
return False
return super().isAccesibleFrom(request)

View File

@ -178,7 +178,7 @@ def webLoginRequired(
# Helper for checking if requests is from trusted source
def isTrustedSource(ip: str) -> bool:
return net.ipInNetwork(ip, GlobalConfig.TRUSTED_SOURCES.get(True))
return net.contains(ip, GlobalConfig.TRUSTED_SOURCES.get(True))
# Decorator to protect pages that needs to be accessed from "trusted sites"

View File

@ -216,7 +216,7 @@ def networkFromString(
def networksFromString(
strNets: str,
nets: str,
version: typing.Literal[0, 4, 6] = 0,
) -> typing.List[NetworkType]:
"""
@ -224,15 +224,15 @@ def networksFromString(
Returns a list of networks tuples in the form [(start1, end1), (start2, end2) ...]
"""
res = []
for strNet in re.split('[;,]', strNets):
for strNet in re.split('[;,]', nets):
if strNet:
res.append(networkFromString(strNet, version))
return res
def ipInNetwork(
ip: typing.Union[str, int],
def contains(
networks: typing.Union[str, NetworkType, typing.List[NetworkType]],
ip: typing.Union[str, int],
version: typing.Literal[0, 4, 6] = 0,
) -> bool:
if isinstance(ip, str):

View File

@ -198,7 +198,7 @@ class EmailMFA(mfas.MFA):
def checkAction(self, action: str, request: 'ExtendedHttpRequest') -> bool:
def checkIp() -> bool:
return any(
i.ipInNetwork(request.ip)
i.contains(request.ip)
for i in models.Network.objects.filter(uuid__in=self.networks.value)
)

View File

@ -166,7 +166,7 @@ class RadiusOTP(mfas.MFA):
def checkAction(self, action: str, request: 'ExtendedHttpRequest') -> bool:
def checkIp() -> bool:
return any(
i.ipInNetwork(request.ip)
i.contains(request.ip)
for i in models.Network.objects.filter(uuid__in=self.networks.value)
)

View File

@ -270,7 +270,7 @@ class SMSMFA(mfas.MFA):
def checkAction(self, action: str, request: 'ExtendedHttpRequest') -> bool:
def checkIp() -> bool:
return any(
i.ipInNetwork(request.ip)
i.contains(request.ip)
for i in models.Network.objects.filter(uuid__in=self.networks.value)
)

View File

@ -181,9 +181,9 @@ class Network(UUIDModel, TaggingMixin): # type: ignore
"""
return net.longToIp(self.net_end)
def ipInNetwork(self, ip: str) -> bool:
def contains(self, ip: str) -> bool:
"""
Returns true if the specified ip is in this network
Returns True if the specified ip is in this network
"""
# if net_string is '*', then we are in all networks, return true
if self.net_string == '*':
@ -191,6 +191,8 @@ class Network(UUIDModel, TaggingMixin): # type: ignore
ipInt, version = net.ipToLong(ip)
return self.net_start <= ipInt <= self.net_end and self.version == version
__contains__ = contains
def save(self, *args, **kwargs) -> None:
"""
Overrides save to update the start, end and version fields
@ -201,21 +203,6 @@ class Network(UUIDModel, TaggingMixin): # type: ignore
self.version = rng.version
super().save(*args, **kwargs)
def update(self, name: str, netRange: str):
"""
Updated this network with provided values
Args:
name: new name of the network
netStart: new Network start (quad dotted)
netEnd: new Network end (quad dotted)
"""
self.name = name
self.net_string = netRange
self.save()
def __str__(self) -> str:
return u'Network {} ({}) from {} to {} ({})'.format(
self.name,

View File

@ -51,7 +51,7 @@ class UUIDModel(models.Model):
# Just a fake declaration to allow type checking
id: int
class Meta: # pylint: disable=too-few-public-methods
class Meta:
abstract = True
def genUuid(self) -> str:

View File

@ -152,7 +152,7 @@ class PhysicalMachinesProvider(services.ServiceProvider):
config = configparser.ConfigParser()
config.read_string(self.config.value)
for key in config['wol']:
if net.ipInNetwork(ip, key):
if net.contains(key, ip):
return config['wol'][key].replace('{MAC}', mac).replace('{IP}', ip)
except Exception as e: