1
0
mirror of https://github.com/dkmstr/openuds.git synced 2025-01-03 01:17:56 +03:00

Fixing typing also on tests

This commit is contained in:
Adolfo Gómez García 2024-02-26 05:50:36 +01:00
parent fc35c18fea
commit 3bf7af2e3d
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
32 changed files with 226 additions and 328 deletions

View File

@ -773,7 +773,7 @@ class gui:
return self.as_date() return self.as_date()
@value.setter @value.setter
def value(self, value: datetime.date) -> None: def value(self, value: datetime.date|str) -> None:
self._set_value(value) self._set_value(value)
def gui_description(self) -> dict[str, typing.Any]: def gui_description(self) -> dict[str, typing.Any]:

View File

@ -29,6 +29,7 @@
Author: Adolfo Gómez, dkmaster at dkmon dot com Author: Adolfo Gómez, dkmaster at dkmon dot com
""" """
import logging import logging
import typing
from uds.core import consts from uds.core import consts
from uds.core.consts.actor import UNMANAGED from uds.core.consts.actor import UNMANAGED
@ -59,25 +60,25 @@ class ActorTestTest(rest.test.RESTActorTestCase):
self.assertEqual(response.json()['error'], 'invalid token') self.assertEqual(response.json()['error'], 'invalid token')
# #
success = lambda: self.client.post( _success: typing.Callable[[], typing.Any] = lambda: self.client.post(
'/uds/rest/actor/v3/test', '/uds/rest/actor/v3/test',
data={'type': type_, 'token': token}, data={'type': type_, 'token': token},
content_type='application/json', content_type='application/json',
) )
invalid = lambda: self.client.post( _invalid: typing.Callable[[], typing.Any] = lambda: self.client.post(
'/uds/rest/actor/v3/test', '/uds/rest/actor/v3/test',
data={'type': type_, 'token': 'invalid'}, data={'type': type_, 'token': 'invalid'},
content_type='application/json', content_type='application/json',
) )
# Invalid actor token also fails # Invalid actor token also fails
response = invalid() response = _invalid()
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertEqual(response.json()['error'], 'invalid token') self.assertEqual(response.json()['error'], 'invalid token')
# This one works # This one works
response = success() response = _success()
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertEqual(response.json()['result'], 'ok') self.assertEqual(response.json()['result'], 'ok')
@ -85,20 +86,20 @@ class ActorTestTest(rest.test.RESTActorTestCase):
# And this one too, without authentication token # And this one too, without authentication token
# Without header, test will success because its not authenticated # Without header, test will success because its not authenticated
self.client.add_header(consts.auth.AUTH_TOKEN_HEADER, 'invalid') self.client.add_header(consts.auth.AUTH_TOKEN_HEADER, 'invalid')
response = success() response = _success()
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertEqual(response.json()['result'], 'ok') self.assertEqual(response.json()['result'], 'ok')
# We have ALLOWED_FAILS until we get blocked for a while # We have ALLOWED_FAILS until we get blocked for a while
# Next one will give 403 # Next one will give 403
for a in range(consts.system.ALLOWED_FAILS): for _i in range(consts.system.ALLOWED_FAILS):
response = invalid() response = _invalid()
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertEqual(response.json()['error'], 'invalid token') self.assertEqual(response.json()['error'], 'invalid token')
# And this one will give 403 # And this one will give 403
response = invalid() response = _invalid()
self.assertEqual(response.status_code, 403) self.assertEqual(response.status_code, 403)
def test_test_managed(self) -> None: def test_test_managed(self) -> None:
@ -108,7 +109,7 @@ class ActorTestTest(rest.test.RESTActorTestCase):
def test_test_unmanaged(self) -> None: def test_test_unmanaged(self) -> None:
# try for a first few services # try for a first few services
service = self.user_service_managed.deployed_service.service service = self.user_service_managed.deployed_service.service
actor_token = self.login_and_register() _actor_token = self.login_and_register()
# Get service token # Get service token
self.do_test(UNMANAGED, service.token or '') self.do_test(UNMANAGED, service.token or '')

View File

@ -113,7 +113,7 @@ class ServerRegisterTest(rest.test.RESTTestCase):
self._data2['mac'] = random_mac() self._data2['mac'] = random_mac()
self._data2['os'] = ( self._data2['os'] = (
types.os.KnownOS.UNKNOWN.value[0] types.os.KnownOS.UNKNOWN.value[0]
if os != types.os.KnownOS.UNKNOWN if os != types.os.KnownOS.UNKNOWN.value[0]
else types.os.KnownOS.WINDOWS.value[0] else types.os.KnownOS.WINDOWS.value[0]
) )
response = self.client.rest_post( response = self.client.rest_post(
@ -135,8 +135,6 @@ class ServerRegisterTest(rest.test.RESTTestCase):
# Rest of fields should be the same # Rest of fields should be the same
def test_invalid_register(self) -> None: def test_invalid_register(self) -> None:
response: 'UDSHttpResponse'
def _do_test(where: str) -> None: def _do_test(where: str) -> None:
response = self.client.rest_post( response = self.client.rest_post(
'servers/register', 'servers/register',

View File

@ -44,7 +44,7 @@ logger = logging.getLogger(__name__)
class SystemTest(rest.test.RESTTestCase): class SystemTest(rest.test.RESTTestCase):
def test_overview(self): def test_overview(self) -> None:
# If not logged in, will fail # If not logged in, will fail
response = self.client.rest_get('system/overview') response = self.client.rest_get('system/overview')
self.assertEqual(response.status_code, 403) self.assertEqual(response.status_code, 403)
@ -70,7 +70,7 @@ class SystemTest(rest.test.RESTTestCase):
self.assertEqual(json['restrained_services_pools'], 0) self.assertEqual(json['restrained_services_pools'], 0)
def test_chart_pool(self): def test_chart_pool(self) -> None:
# First, create fixtures for the pool # First, create fixtures for the pool
DAYS = 30 DAYS = 30
for pool in [self.user_service_managed, self.user_service_unmanaged]: for pool in [self.user_service_managed, self.user_service_unmanaged]:

View File

@ -71,7 +71,7 @@ SERIALIZED_AUTH_DATA: typing.Final[typing.Mapping[str, bytes]] = {
class SimpleLdapSerializationTest(UDSTestCase): class SimpleLdapSerializationTest(UDSTestCase):
def check_provider(self, version: str, instance: 'authenticator.SimpleLDAPAuthenticator'): def check_provider(self, version: str, instance: 'authenticator.SimpleLDAPAuthenticator') -> None:
self.assertEqual(instance.host.as_str(), 'host') self.assertEqual(instance.host.as_str(), 'host')
self.assertEqual(instance.port.as_int(), 166) self.assertEqual(instance.port.as_int(), 166)
self.assertEqual(instance.use_ssl.as_bool(), True) self.assertEqual(instance.use_ssl.as_bool(), True)
@ -91,14 +91,14 @@ class SimpleLdapSerializationTest(UDSTestCase):
self.assertEqual(instance.verify_ssl.as_bool(), True) self.assertEqual(instance.verify_ssl.as_bool(), True)
self.assertEqual(instance.certificate.as_str(), 'cert') self.assertEqual(instance.certificate.as_str(), 'cert')
def test_unmarshall_all_versions(self): def test_unmarshall_all_versions(self) -> None:
for v in range(1, len(SERIALIZED_AUTH_DATA) + 1): for v in range(1, len(SERIALIZED_AUTH_DATA) + 1):
with Environment.temporary_environment() as env: with Environment.temporary_environment() as env:
instance = authenticator.SimpleLDAPAuthenticator(environment=env) instance = authenticator.SimpleLDAPAuthenticator(environment=env)
instance.unmarshal(SERIALIZED_AUTH_DATA['v{}'.format(v)]) instance.unmarshal(SERIALIZED_AUTH_DATA['v{}'.format(v)])
self.check_provider(f'v{v}', instance) self.check_provider(f'v{v}', instance)
def test_marshaling(self): def test_marshaling(self) -> None:
# Unmarshall last version, remarshall and check that is marshalled using new marshalling format # Unmarshall last version, remarshall and check that is marshalled using new marshalling format
LAST_VERSION = 'v{}'.format(len(SERIALIZED_AUTH_DATA)) LAST_VERSION = 'v{}'.format(len(SERIALIZED_AUTH_DATA))
with Environment.temporary_environment() as env: with Environment.temporary_environment() as env:

View File

@ -33,6 +33,7 @@
import time import time
import threading import threading
from tracemalloc import stop from tracemalloc import stop
import typing
from unittest import mock from unittest import mock
from django.test import TransactionTestCase from django.test import TransactionTestCase
@ -57,24 +58,24 @@ class SchedulerTest(TransactionTestCase):
) as mock_ensure_jobs_registered: ) as mock_ensure_jobs_registered:
left = 4 left = 4
def our_execute_job(*args, **kwargs) -> None: def _our_execute_job(*args: typing.Any, **kwargs: typing.Any) -> None:
nonlocal left nonlocal left
left -= 1 left -= 1
if left == 0: if left == 0:
sch.notify_termination() sch.notify_termination()
mock_execute_job.side_effect = our_execute_job mock_execute_job.side_effect = _our_execute_job
# Execute run, but if it does not call execute_job, it will hang # Execute run, but if it does not call execute_job, it will hang
# so we execute a thread that will call notify_termination after 1 second # so we execute a thread that will call notify_termination after 1 second
stop_event = threading.Event() stop_event = threading.Event()
def ensure_stops() -> None: def _ensure_stops() -> None:
stop_event.wait(scheduler.Scheduler.granularity * 10) stop_event.wait(scheduler.Scheduler.granularity * 10)
if left > 0: if left > 0:
sch.notify_termination() sch.notify_termination()
watchdog = threading.Thread(target=ensure_stops) watchdog = threading.Thread(target=_ensure_stops)
watchdog.start() watchdog.start()
sch.run() sch.run()

View File

@ -50,7 +50,7 @@ class DownloadsManagerTest(WEBTestCase):
manager: DownloadsManager manager: DownloadsManager
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls) -> None:
from uds.core.managers import downloads_manager from uds.core.managers import downloads_manager
super().setUpClass() super().setUpClass()
@ -60,7 +60,7 @@ class DownloadsManagerTest(WEBTestCase):
) )
cls.manager = downloads_manager() cls.manager = downloads_manager()
def test_downloads(self): def test_downloads(self) -> None:
for v in ( for v in (
('test.txt', 'text/plain', '1f47ec0a-1ad4-5d63-b41c-5d2befadab8d'), ('test.txt', 'text/plain', '1f47ec0a-1ad4-5d63-b41c-5d2befadab8d'),
( (

View File

@ -132,9 +132,10 @@ class ModelAccountTest(UDSTestCase):
acc = models.Account.objects.create(name='Test Account') acc = models.Account.objects.create(name='Test Account')
for i in range(NUM_USERSERVICES): for i in range(NUM_USERSERVICES):
usage = acc.start_accounting(self.user_services[i]) usage = acc.start_accounting(self.user_services[i])
self.assertIsNotNone(usage) if not usage:
usage.start = usage.start - datetime.timedelta(seconds=32 + i) # type: ignore self.fail('Usage not created')
usage.save(update_fields=['start']) # type: ignore usage.start = usage.start - datetime.timedelta(seconds=32 + i)
usage.save(update_fields=['start'])
usage_end = acc.stop_accounting(self.user_services[i]) usage_end = acc.stop_accounting(self.user_services[i])
self.assertIsNotNone(usage_end) self.assertIsNotNone(usage_end)

View File

@ -51,7 +51,7 @@ class ModelUUIDTest(UDSTestCase):
self.group = authenticators_fixtures.create_groups(self.auth, 1)[0] self.group = authenticators_fixtures.create_groups(self.auth, 1)[0]
self.user = authenticators_fixtures.create_users(self.auth, 1, groups=[self.group])[0] self.user = authenticators_fixtures.create_users(self.auth, 1, groups=[self.group])[0]
def test_uuid_lowercase(self): def test_uuid_lowercase(self) -> None:
""" """
Tests that uuids are always lowercase Tests that uuids are always lowercase
""" """

View File

@ -47,7 +47,7 @@ class GraphsTest(UDSTestCase):
data1: dict[str, typing.Any] data1: dict[str, typing.Any]
data2: dict[str, typing.Any] data2: dict[str, typing.Any]
def setUp(self): def setUp(self) -> None:
# Data must be a dict with the following keys: # Data must be a dict with the following keys:
# - x: List of x values # - x: List of x values
# - y: List of dicts with the following keys: # - y: List of dicts with the following keys:
@ -109,7 +109,7 @@ class GraphsTest(UDSTestCase):
} }
def test_bar_chart(self): def test_bar_chart(self) -> None:
output = io.BytesIO() output = io.BytesIO()
graphs.bar_chart((10, 8, 96), data=self.data1, output=output) graphs.bar_chart((10, 8, 96), data=self.data1, output=output)
value = output.getvalue() value = output.getvalue()
@ -120,7 +120,7 @@ class GraphsTest(UDSTestCase):
f.write(value) f.write(value)
def test_line_chart(self): def test_line_chart(self) -> None:
output = io.BytesIO() output = io.BytesIO()
graphs.line_chart((10, 8, 96), data=self.data1, output=output) graphs.line_chart((10, 8, 96), data=self.data1, output=output)
value = output.getvalue() value = output.getvalue()
@ -131,7 +131,7 @@ class GraphsTest(UDSTestCase):
f.write(value) f.write(value)
def test_surface_chart(self): def test_surface_chart(self) -> None:
output = io.BytesIO() output = io.BytesIO()
graphs.surface_chart((10, 8, 96), data=self.data2, output=output) graphs.surface_chart((10, 8, 96), data=self.data2, output=output)
value = output.getvalue() value = output.getvalue()
@ -143,7 +143,7 @@ class GraphsTest(UDSTestCase):
f.write(value) f.write(value)
def test_surface_chart_wireframe(self): def test_surface_chart_wireframe(self) -> None:
self.data2['wireframe'] = True self.data2['wireframe'] = True
output = io.BytesIO() output = io.BytesIO()
graphs.surface_chart((10, 8, 96), data=self.data2, output=output) graphs.surface_chart((10, 8, 96), data=self.data2, output=output)

View File

@ -39,7 +39,7 @@ from django.conf import settings
from uds.core.util import ensure from uds.core.util import ensure
class GuiTest(UDSTestCase): class GuiTest(UDSTestCase):
def test_globals(self): def test_globals(self) -> None:
self.assertEqual(UDSK, settings.SECRET_KEY[8:24].encode()) self.assertEqual(UDSK, settings.SECRET_KEY[8:24].encode())
self.assertEqual(UDSB, b'udsprotect') self.assertEqual(UDSB, b'udsprotect')

View File

@ -48,7 +48,7 @@ logger = logging.getLogger(__name__)
class UserinterfaceInternalTest(UDSTestCase): class UserinterfaceInternalTest(UDSTestCase):
def test_value(self): def test_value(self) -> None:
# Asserts that data is correctly stored and retrieved # Asserts that data is correctly stored and retrieved
ui = TestingUserInterface() ui = TestingUserInterface()
self.assertEqual(ui.str_field.value, DEFAULTS['str_field']) self.assertEqual(ui.str_field.value, DEFAULTS['str_field'])
@ -88,7 +88,7 @@ class UserinterfaceInternalTest(UDSTestCase):
self.assertEqual(ui.help_field.value, DEFAULTS['help_field']) self.assertEqual(ui.help_field.value, DEFAULTS['help_field'])
def test_default(self): def test_default(self) -> None:
ui = TestingUserInterface() ui = TestingUserInterface()
# Now for default values # Now for default values
self.assertEqual(ui.str_field.default, DEFAULTS['str_field']) self.assertEqual(ui.str_field.default, DEFAULTS['str_field'])
@ -104,7 +104,7 @@ class UserinterfaceInternalTest(UDSTestCase):
self.assertEqual(ui.date_field.default, DEFAULTS['date_field']) self.assertEqual(ui.date_field.default, DEFAULTS['date_field'])
self.assertEqual(ui.help_field.default, DEFAULTS['help_field']) self.assertEqual(ui.help_field.default, DEFAULTS['help_field'])
def test_references(self): def test_references(self) -> None:
ui = TestingUserInterface() ui = TestingUserInterface()
# Ensure references are fine # Ensure references are fine
self.assertEqual(ui.str_field.value, ui._gui['str_field'].value) self.assertEqual(ui.str_field.value, ui._gui['str_field'].value)
@ -120,7 +120,7 @@ class UserinterfaceInternalTest(UDSTestCase):
self.assertEqual(ui.date_field.value, ui._gui['date_field'].value) self.assertEqual(ui.date_field.value, ui._gui['date_field'].value)
self.assertEqual(ui.help_field.value, ui._gui['help_field'].value) self.assertEqual(ui.help_field.value, ui._gui['help_field'].value)
def test_modify(self): def test_modify(self) -> None:
ui = TestingUserInterface() ui = TestingUserInterface()
# Modify values, and recheck references # Modify values, and recheck references
ui.str_field.value = 'New value' ui.str_field.value = 'New value'
@ -148,7 +148,7 @@ class UserinterfaceInternalTest(UDSTestCase):
ui.help_field.value = 'New value' ui.help_field.value = 'New value'
self.assertEqual(ui.help_field.value, ui._gui['help_field'].value) self.assertEqual(ui.help_field.value, ui._gui['help_field'].value)
def test_valuesDict(self): def test_valuesDict(self) -> None:
ui = TestingUserInterface() ui = TestingUserInterface()
self.assertEqual( self.assertEqual(
ui.get_fields_as_dict(), ui.get_fields_as_dict(),
@ -168,7 +168,7 @@ class UserinterfaceInternalTest(UDSTestCase):
}, },
) )
def test_labels(self): def test_labels(self) -> None:
ui = TestingUserInterface() ui = TestingUserInterface()
self.assertEqual( self.assertEqual(
{ k: v.label for k, v in ui._gui.items() }, { k: v.label for k, v in ui._gui.items() },
@ -188,7 +188,7 @@ class UserinterfaceInternalTest(UDSTestCase):
}, },
) )
def test_order(self): def test_order(self) -> None:
ui = TestingUserInterface() ui = TestingUserInterface()
self.assertEqual( self.assertEqual(
{ k: v._fields_info.order for k, v in ui._gui.items() }, { k: v._fields_info.order for k, v in ui._gui.items() },
@ -208,7 +208,7 @@ class UserinterfaceInternalTest(UDSTestCase):
}, },
) )
def test_required(self): def test_required(self) -> None:
ui = TestingUserInterface() ui = TestingUserInterface()
self.assertEqual( self.assertEqual(
{ k: v._fields_info.required for k, v in ui._gui.items() }, { k: v._fields_info.required for k, v in ui._gui.items() },
@ -228,7 +228,7 @@ class UserinterfaceInternalTest(UDSTestCase):
}, },
) )
def test_tooltip(self): def test_tooltip(self) -> None:
ui = TestingUserInterface() ui = TestingUserInterface()
self.assertEqual( self.assertEqual(
{ k: v._fields_info.tooltip for k, v in ui._gui.items() }, { k: v._fields_info.tooltip for k, v in ui._gui.items() },

View File

@ -39,7 +39,7 @@ import collections.abc
from ...utils.test import UDSTestCase from ...utils.test import UDSTestCase
from uds.core import types, consts from uds.core import types, consts
from uds.core.ui.user_interface import gui from uds.core.ui.user_interface import UserInterface
from unittest import mock from unittest import mock
from ...fixtures.user_interface import ( from ...fixtures.user_interface import (
@ -52,7 +52,7 @@ from ...fixtures.user_interface import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def old_serialize_form(ui) -> bytes: def old_serialize_form(ui: 'UserInterface') -> bytes:
""" """
All values stored at form fields are serialized and returned as a single All values stored at form fields are serialized and returned as a single
string string
@ -76,7 +76,7 @@ def old_serialize_form(ui) -> bytes:
# Separators for fields, old implementation # Separators for fields, old implementation
MULTIVALUE_FIELD: typing.Final[bytes] = b'\001' MULTIVALUE_FIELD: typing.Final[bytes] = b'\001'
OLD_PASSWORD_FIELD: typing.Final[bytes] = b'\004' # OLD_PASSWORD_FIELD: typing.Final[bytes] = b'\004'
PASSWORD_FIELD: typing.Final[bytes] = b'\005' PASSWORD_FIELD: typing.Final[bytes] = b'\005'
FIELD_SEPARATOR: typing.Final[bytes] = b'\002' FIELD_SEPARATOR: typing.Final[bytes] = b'\002'
@ -85,7 +85,7 @@ def old_serialize_form(ui) -> bytes:
# import inspect # import inspect
# logger.debug('Caller is : {}'.format(inspect.stack())) # logger.debug('Caller is : {}'.format(inspect.stack()))
arr = [] arr: list[bytes] = []
val: typing.Any val: typing.Any
for k, v in ui._gui.items(): for k, v in ui._gui.items():
logger.debug('serializing Key: %s/%s', k, v.value) logger.debug('serializing Key: %s/%s', k, v.value)
@ -147,7 +147,7 @@ class UserinterfaceTest(UDSTestCase):
) )
self.assertEqual(ui.date_field.value, DEFAULTS['date_field'], 'date_field') self.assertEqual(ui.date_field.value, DEFAULTS['date_field'], 'date_field')
def test_old_serialization(self): def test_old_serialization(self) -> None:
# This test is to ensure that old serialized data can be loaded # This test is to ensure that old serialized data can be loaded
# This data is from a # This data is from a
ui = TestingUserInterface() ui = TestingUserInterface()
@ -165,7 +165,7 @@ class UserinterfaceTest(UDSTestCase):
self.assertEqual(ui, ui3) self.assertEqual(ui, ui3)
self.ensure_values_fine(ui3) self.ensure_values_fine(ui3)
def test_new_serialization(self): def test_new_serialization(self) -> None:
# This test is to ensure that new serialized data can be loaded # This test is to ensure that new serialized data can be loaded
# First # First
ui = TestingUserInterface() ui = TestingUserInterface()
@ -176,7 +176,7 @@ class UserinterfaceTest(UDSTestCase):
self.assertEqual(ui, ui2) self.assertEqual(ui, ui2)
self.ensure_values_fine(ui2) self.ensure_values_fine(ui2)
def test_old_field_name(self): def test_old_field_name(self) -> None:
# This test is to ensure that new serialized data can be loaded # This test is to ensure that new serialized data can be loaded
# mock logging warning # mock logging warning
ui = TestingUserInterfaceFieldNameOrig() ui = TestingUserInterfaceFieldNameOrig()

View File

@ -185,7 +185,7 @@ class AutoSerializable(UDSTestCase):
def test_auto_serializable_base_encrypted(self) -> None: def test_auto_serializable_base_encrypted(self) -> None:
self.basic_check(AutoSerializableClass, AutoSerializableEncryptedClass) self.basic_check(AutoSerializableClass, AutoSerializableEncryptedClass)
def test_auto_serializable_derived(self): def test_auto_serializable_derived(self) -> None:
instance = DerivedAutoSerializableClass() instance = DerivedAutoSerializableClass()
instance.int_field = 1 instance.int_field = 1
instance.str_field = UNICODE_CHARS instance.str_field = UNICODE_CHARS
@ -204,7 +204,7 @@ class AutoSerializable(UDSTestCase):
self.assertEqual(instance, instance2) self.assertEqual(instance, instance2)
def test_auto_serializable_derived_added(self): def test_auto_serializable_derived_added(self) -> None:
instance = DerivedAutoSerializableClass2() instance = DerivedAutoSerializableClass2()
instance.int_field = 1 instance.int_field = 1
instance.str_field = UNICODE_CHARS instance.str_field = UNICODE_CHARS
@ -225,7 +225,7 @@ class AutoSerializable(UDSTestCase):
self.assertEqual(instance, instance2) self.assertEqual(instance, instance2)
def test_auto_serializable_with_missing_fields(self): def test_auto_serializable_with_missing_fields(self) -> None:
instance = AutoSerializableClass() instance = AutoSerializableClass()
instance.int_field = 1 instance.int_field = 1
instance.str_field = UNICODE_CHARS instance.str_field = UNICODE_CHARS
@ -250,7 +250,7 @@ class AutoSerializable(UDSTestCase):
self.assertEqual(instance2.list_field, [1, 2, 3]) self.assertEqual(instance2.list_field, [1, 2, 3])
self.assertEqual(instance2.obj_nt_field, SerializableNamedTuple(2, '3', 4.0)) self.assertEqual(instance2.obj_nt_field, SerializableNamedTuple(2, '3', 4.0))
def test_auto_serializable_with_added_fields(self): def test_auto_serializable_with_added_fields(self) -> None:
instance = AutoSerializableClassWithMissingFields() instance = AutoSerializableClassWithMissingFields()
instance.int_field = 1 instance.int_field = 1
instance.bool_field = True instance.bool_field = True

View File

@ -44,7 +44,7 @@ VALUE_1 = [u'únîcödè€', b'string', {'a': 1, 'b': 2.0}]
class CacheTest(UDSTransactionTestCase): class CacheTest(UDSTransactionTestCase):
def test_cache(self): def test_cache(self) -> None:
cache = Cache(UNICODE_CHARS) cache = Cache(UNICODE_CHARS)
# Get default value, with unicode # Get default value, with unicode

View File

@ -42,7 +42,7 @@ class CalendarTest(UDSTestCase):
def setUp(self) -> None: def setUp(self) -> None:
createCalendars() createCalendars()
def test_calendar_dayly(self): def test_calendar_dayly(self) -> None:
cal = Calendar.objects.get(uuid='2cf6846b-d889-57ce-bb35-e647040a95b6') cal = Calendar.objects.get(uuid='2cf6846b-d889-57ce-bb35-e647040a95b6')
chk = calendar.CalendarChecker(cal) chk = calendar.CalendarChecker(cal)
calendar.CalendarChecker.updates = 0 calendar.CalendarChecker.updates = 0
@ -94,7 +94,7 @@ class CalendarTest(UDSTestCase):
self.assertEqual(chk.updates, 366) self.assertEqual(chk.updates, 366)
def test_calendar_weekly(self): def test_calendar_weekly(self) -> None:
cal = Calendar.objects.get(uuid='c1221a6d-3848-5fa3-ae98-172662c0f554') cal = Calendar.objects.get(uuid='c1221a6d-3848-5fa3-ae98-172662c0f554')
chk = calendar.CalendarChecker(cal) chk = calendar.CalendarChecker(cal)
calendar.CalendarChecker.updates = 0 calendar.CalendarChecker.updates = 0
@ -144,7 +144,7 @@ class CalendarTest(UDSTestCase):
self.assertEqual(chk.updates, 365) self.assertEqual(chk.updates, 365)
def test_calendar_monthly(self): def test_calendar_monthly(self) -> None:
cal = Calendar.objects.get(uuid='353c4cb8-e02d-5387-a18f-f634729fde81') cal = Calendar.objects.get(uuid='353c4cb8-e02d-5387-a18f-f634729fde81')
chk = calendar.CalendarChecker(cal) chk = calendar.CalendarChecker(cal)
calendar.CalendarChecker.updates = 0 calendar.CalendarChecker.updates = 0
@ -186,7 +186,7 @@ class CalendarTest(UDSTestCase):
self.assertEqual(chk.updates, 730) self.assertEqual(chk.updates, 730)
def test_calendar_weekdays(self): def test_calendar_weekdays(self) -> None:
cal = Calendar.objects.get(uuid='bccfd011-605b-565f-a08e-80bf75114dce') cal = Calendar.objects.get(uuid='bccfd011-605b-565f-a08e-80bf75114dce')
chk = calendar.CalendarChecker(cal) chk = calendar.CalendarChecker(cal)
calendar.CalendarChecker.updates = 0 calendar.CalendarChecker.updates = 0
@ -229,7 +229,7 @@ class CalendarTest(UDSTestCase):
self.assertEqual(chk.updates, 730) self.assertEqual(chk.updates, 730)
def test_calendar_durations(self): def test_calendar_durations(self) -> None:
cal = Calendar.objects.get(uuid='60160f94-c8fe-5fdc-bbbe-325010980106') cal = Calendar.objects.get(uuid='60160f94-c8fe-5fdc-bbbe-325010980106')
chk = calendar.CalendarChecker(cal) chk = calendar.CalendarChecker(cal)

View File

@ -35,6 +35,8 @@ import typing
import collections.abc import collections.abc
import uds.core.types.permissions import uds.core.types.permissions
from django.db.models import Model
from uds.core.util import permissions from uds.core.util import permissions
from uds.core.util import objtype from uds.core.util import objtype
from uds import models from uds import models
@ -62,9 +64,7 @@ class PermissionsTest(UDSTestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.authenticator = authenticators_fixtures.create_authenticator() self.authenticator = authenticators_fixtures.create_authenticator()
self.groups = authenticators_fixtures.create_groups(self.authenticator) self.groups = authenticators_fixtures.create_groups(self.authenticator)
self.users = authenticators_fixtures.create_users( self.users = authenticators_fixtures.create_users(self.authenticator, groups=self.groups)
self.authenticator, groups=self.groups
)
self.admins = authenticators_fixtures.create_users( self.admins = authenticators_fixtures.create_users(
self.authenticator, is_admin=True, groups=self.groups self.authenticator, is_admin=True, groups=self.groups
) )
@ -83,98 +83,54 @@ class PermissionsTest(UDSTestCase):
self.network = network_fixtures.createNetwork() self.network = network_fixtures.createNetwork()
def doTestUserPermissions(self, obj, user: models.User): def doTestUserPermissions(self, obj: 'Model', user: models.User) -> None:
permissions.add_user_permission( permissions.add_user_permission(user, obj, uds.core.types.permissions.PermissionType.NONE)
user, obj, uds.core.types.permissions.PermissionType.NONE
)
self.assertEqual(models.Permissions.objects.count(), 1) self.assertEqual(models.Permissions.objects.count(), 1)
perm = models.Permissions.objects.all()[0] perm = models.Permissions.objects.all()[0]
self.assertEqual(perm.object_type, objtype.ObjectType.from_model(obj).type) self.assertEqual(perm.object_type, objtype.ObjectType.from_model(obj).type)
self.assertEqual(perm.object_id, obj.pk) self.assertEqual(perm.object_id, obj.pk)
self.assertEqual(perm.permission, uds.core.types.permissions.PermissionType.NONE) self.assertEqual(perm.permission, uds.core.types.permissions.PermissionType.NONE)
self.assertTrue( self.assertTrue(permissions.has_access(user, obj, uds.core.types.permissions.PermissionType.NONE))
permissions.has_access(
user, obj, uds.core.types.permissions.PermissionType.NONE
)
)
self.assertEqual( self.assertEqual(
permissions.has_access( permissions.has_access(user, obj, uds.core.types.permissions.PermissionType.READ),
user, obj, uds.core.types.permissions.PermissionType.READ
),
user.is_admin, user.is_admin,
) )
self.assertEqual( self.assertEqual(
permissions.has_access( permissions.has_access(user, obj, uds.core.types.permissions.PermissionType.MANAGEMENT),
user, obj, uds.core.types.permissions.PermissionType.MANAGEMENT
),
user.is_admin, user.is_admin,
) )
self.assertEqual( self.assertEqual(
permissions.has_access( permissions.has_access(user, obj, uds.core.types.permissions.PermissionType.ALL),
user, obj, uds.core.types.permissions.PermissionType.ALL
),
user.is_admin, user.is_admin,
) )
# Add a new permission, must overwrite the previous one # Add a new permission, must overwrite the previous one
permissions.add_user_permission( permissions.add_user_permission(user, obj, uds.core.types.permissions.PermissionType.ALL)
user, obj, uds.core.types.permissions.PermissionType.ALL
)
self.assertEqual(models.Permissions.objects.count(), 1) self.assertEqual(models.Permissions.objects.count(), 1)
perm = models.Permissions.objects.all()[0] perm = models.Permissions.objects.all()[0]
self.assertEqual(perm.object_type, PermissionsTest.getObjectType(obj)) self.assertEqual(perm.object_type, PermissionsTest.getObjectType(obj))
self.assertEqual(perm.object_id, obj.pk) self.assertEqual(perm.object_id, obj.pk)
self.assertEqual(perm.permission, uds.core.types.permissions.PermissionType.ALL) self.assertEqual(perm.permission, uds.core.types.permissions.PermissionType.ALL)
self.assertTrue( self.assertTrue(permissions.has_access(user, obj, uds.core.types.permissions.PermissionType.NONE))
permissions.has_access( self.assertTrue(permissions.has_access(user, obj, uds.core.types.permissions.PermissionType.READ))
user, obj, uds.core.types.permissions.PermissionType.NONE self.assertTrue(permissions.has_access(user, obj, uds.core.types.permissions.PermissionType.MANAGEMENT))
) self.assertTrue(permissions.has_access(user, obj, uds.core.types.permissions.PermissionType.ALL))
)
self.assertTrue(
permissions.has_access(
user, obj, uds.core.types.permissions.PermissionType.READ
)
)
self.assertTrue(
permissions.has_access(
user, obj, uds.core.types.permissions.PermissionType.MANAGEMENT
)
)
self.assertTrue(
permissions.has_access(
user, obj, uds.core.types.permissions.PermissionType.ALL
)
)
# Again, with read # Again, with read
permissions.add_user_permission( permissions.add_user_permission(user, obj, uds.core.types.permissions.PermissionType.READ)
user, obj, uds.core.types.permissions.PermissionType.READ
)
self.assertEqual(models.Permissions.objects.count(), 1) self.assertEqual(models.Permissions.objects.count(), 1)
perm = models.Permissions.objects.all()[0] perm = models.Permissions.objects.all()[0]
self.assertEqual(perm.object_type, PermissionsTest.getObjectType(obj)) self.assertEqual(perm.object_type, PermissionsTest.getObjectType(obj))
self.assertEqual(perm.object_id, obj.pk) self.assertEqual(perm.object_id, obj.pk)
self.assertEqual(perm.permission, uds.core.types.permissions.PermissionType.READ) self.assertEqual(perm.permission, uds.core.types.permissions.PermissionType.READ)
self.assertTrue( self.assertTrue(permissions.has_access(user, obj, uds.core.types.permissions.PermissionType.NONE))
permissions.has_access( self.assertTrue(permissions.has_access(user, obj, uds.core.types.permissions.PermissionType.READ))
user, obj, uds.core.types.permissions.PermissionType.NONE
)
)
self.assertTrue(
permissions.has_access(
user, obj, uds.core.types.permissions.PermissionType.READ
)
)
self.assertEqual( self.assertEqual(
permissions.has_access( permissions.has_access(user, obj, uds.core.types.permissions.PermissionType.MANAGEMENT),
user, obj, uds.core.types.permissions.PermissionType.MANAGEMENT
),
user.is_admin, user.is_admin,
) )
self.assertEqual( self.assertEqual(
permissions.has_access( permissions.has_access(user, obj, uds.core.types.permissions.PermissionType.ALL),
user, obj, uds.core.types.permissions.PermissionType.ALL
),
user.is_admin, user.is_admin,
) )
@ -182,85 +138,47 @@ class PermissionsTest(UDSTestCase):
obj.delete() obj.delete()
self.assertEqual(models.Permissions.objects.count(), 0) self.assertEqual(models.Permissions.objects.count(), 0)
def doTestGroupPermissions(self, obj, user: models.User): def doTestGroupPermissions(self, obj: 'Model', user: models.User):
group = user.groups.all()[0] group = user.groups.all()[0]
permissions.add_group_permission( permissions.add_group_permission(group, obj, uds.core.types.permissions.PermissionType.NONE)
group, obj, uds.core.types.permissions.PermissionType.NONE
)
self.assertEqual(models.Permissions.objects.count(), 1) self.assertEqual(models.Permissions.objects.count(), 1)
perm = models.Permissions.objects.all()[0] perm = models.Permissions.objects.all()[0]
self.assertEqual(perm.object_type, PermissionsTest.getObjectType(obj)) self.assertEqual(perm.object_type, PermissionsTest.getObjectType(obj))
self.assertEqual(perm.object_id, obj.pk) self.assertEqual(perm.object_id, obj.pk)
self.assertEqual(perm.permission, uds.core.types.permissions.PermissionType.NONE) self.assertEqual(perm.permission, uds.core.types.permissions.PermissionType.NONE)
self.assertTrue( self.assertTrue(permissions.has_access(user, obj, uds.core.types.permissions.PermissionType.NONE))
permissions.has_access(
user, obj, uds.core.types.permissions.PermissionType.NONE
)
)
# Admins has all permissions ALWAYS # Admins has all permissions ALWAYS
self.assertEqual( self.assertEqual(
permissions.has_access( permissions.has_access(user, obj, uds.core.types.permissions.PermissionType.READ),
user, obj, uds.core.types.permissions.PermissionType.READ
),
user.is_admin, user.is_admin,
) )
self.assertEqual( self.assertEqual(
permissions.has_access( permissions.has_access(user, obj, uds.core.types.permissions.PermissionType.ALL),
user, obj, uds.core.types.permissions.PermissionType.ALL
),
user.is_admin, user.is_admin,
) )
permissions.add_group_permission( permissions.add_group_permission(group, obj, uds.core.types.permissions.PermissionType.ALL)
group, obj, uds.core.types.permissions.PermissionType.ALL
)
self.assertEqual(models.Permissions.objects.count(), 1) self.assertEqual(models.Permissions.objects.count(), 1)
perm = models.Permissions.objects.all()[0] perm = models.Permissions.objects.all()[0]
self.assertEqual(perm.object_type, PermissionsTest.getObjectType(obj)) self.assertEqual(perm.object_type, PermissionsTest.getObjectType(obj))
self.assertEqual(perm.object_id, obj.pk) self.assertEqual(perm.object_id, obj.pk)
self.assertEqual(perm.permission, uds.core.types.permissions.PermissionType.ALL) self.assertEqual(perm.permission, uds.core.types.permissions.PermissionType.ALL)
self.assertTrue( self.assertTrue(permissions.has_access(user, obj, uds.core.types.permissions.PermissionType.NONE))
permissions.has_access( self.assertTrue(permissions.has_access(user, obj, uds.core.types.permissions.PermissionType.READ))
user, obj, uds.core.types.permissions.PermissionType.NONE self.assertTrue(permissions.has_access(user, obj, uds.core.types.permissions.PermissionType.ALL))
)
)
self.assertTrue(
permissions.has_access(
user, obj, uds.core.types.permissions.PermissionType.READ
)
)
self.assertTrue(
permissions.has_access(
user, obj, uds.core.types.permissions.PermissionType.ALL
)
)
# Add user permission, DB must contain both an return ALL # Add user permission, DB must contain both an return ALL
permissions.add_user_permission( permissions.add_user_permission(user, obj, uds.core.types.permissions.PermissionType.READ)
user, obj, uds.core.types.permissions.PermissionType.READ
)
self.assertEqual(models.Permissions.objects.count(), 2) self.assertEqual(models.Permissions.objects.count(), 2)
perm = models.Permissions.objects.all()[0] perm = models.Permissions.objects.all()[0]
self.assertEqual(perm.object_type, PermissionsTest.getObjectType(obj)) self.assertEqual(perm.object_type, PermissionsTest.getObjectType(obj))
self.assertEqual(perm.object_id, obj.pk) self.assertEqual(perm.object_id, obj.pk)
self.assertEqual(perm.permission, uds.core.types.permissions.PermissionType.ALL) self.assertEqual(perm.permission, uds.core.types.permissions.PermissionType.ALL)
self.assertTrue( self.assertTrue(permissions.has_access(user, obj, uds.core.types.permissions.PermissionType.NONE))
permissions.has_access( self.assertTrue(permissions.has_access(user, obj, uds.core.types.permissions.PermissionType.READ))
user, obj, uds.core.types.permissions.PermissionType.NONE self.assertTrue(permissions.has_access(user, obj, uds.core.types.permissions.PermissionType.ALL))
)
)
self.assertTrue(
permissions.has_access(
user, obj, uds.core.types.permissions.PermissionType.READ
)
)
self.assertTrue(
permissions.has_access(
user, obj, uds.core.types.permissions.PermissionType.ALL
)
)
# Remove obj, permissions must have gone away # Remove obj, permissions must have gone away
obj.delete() obj.delete()
@ -268,150 +186,126 @@ class PermissionsTest(UDSTestCase):
# Every tests reverses the DB and recalls setUp # Every tests reverses the DB and recalls setUp
def test_user_auth_permissions_user(self): def test_user_auth_permissions_user(self) -> None:
self.doTestUserPermissions(self.authenticator, self.users[0]) self.doTestUserPermissions(self.authenticator, self.users[0])
def test_user_auth_permissions_admin(self): def test_user_auth_permissions_admin(self) -> None:
self.doTestUserPermissions(self.authenticator, self.admins[0]) self.doTestUserPermissions(self.authenticator, self.admins[0])
def test_user_auth_permissions_staff(self): def test_user_auth_permissions_staff(self) -> None:
self.doTestUserPermissions(self.authenticator, self.staffs[0]) self.doTestUserPermissions(self.authenticator, self.staffs[0])
def test_group_auth_permissions_user(self): def test_group_auth_permissions_user(self) -> None:
self.doTestGroupPermissions(self.authenticator, self.users[0]) self.doTestGroupPermissions(self.authenticator, self.users[0])
def test_group_auth_permissions_admin(self): def test_group_auth_permissions_admin(self) -> None:
self.doTestGroupPermissions(self.authenticator, self.admins[0]) self.doTestGroupPermissions(self.authenticator, self.admins[0])
def test_group_auth_permissions_staff(self): def test_group_auth_permissions_staff(self) -> None:
self.doTestGroupPermissions(self.authenticator, self.staffs[0]) self.doTestGroupPermissions(self.authenticator, self.staffs[0])
def test_user_servicepool_permissions_user(self): def test_user_servicepool_permissions_user(self) -> None:
self.doTestUserPermissions(self.userService.deployed_service, self.users[0]) self.doTestUserPermissions(self.userService.deployed_service, self.users[0])
def test_user_servicepool_permissions_admin(self): def test_user_servicepool_permissions_admin(self) -> None:
self.doTestUserPermissions(self.userService.deployed_service, self.admins[0]) self.doTestUserPermissions(self.userService.deployed_service, self.admins[0])
def test_user_servicepool_permissions_staff(self): def test_user_servicepool_permissions_staff(self) -> None:
self.doTestUserPermissions(self.userService.deployed_service, self.staffs[0]) self.doTestUserPermissions(self.userService.deployed_service, self.staffs[0])
def test_group_servicepool_permissions_user(self): def test_group_servicepool_permissions_user(self) -> None:
self.doTestGroupPermissions(self.userService.deployed_service, self.users[0]) self.doTestGroupPermissions(self.userService.deployed_service, self.users[0])
def test_group_servicepool_permissions_admin(self): def test_group_servicepool_permissions_admin(self) -> None:
self.doTestGroupPermissions(self.userService.deployed_service, self.admins[0]) self.doTestGroupPermissions(self.userService.deployed_service, self.admins[0])
def test_group_servicepool_permissions_staff(self): def test_group_servicepool_permissions_staff(self) -> None:
self.doTestGroupPermissions(self.userService.deployed_service, self.staffs[0]) self.doTestGroupPermissions(self.userService.deployed_service, self.staffs[0])
def test_user_transport_permissions_user(self): def test_user_transport_permissions_user(self) -> None:
self.doTestUserPermissions( self.doTestUserPermissions(
self.userService.deployed_service.transports.first(), self.users[0] typing.cast(models.Transport, self.userService.deployed_service.transports.first()), self.users[0]
) )
def test_user_transport_permissions_admin(self): def test_user_transport_permissions_admin(self) -> None:
self.doTestUserPermissions( self.doTestUserPermissions(
self.userService.deployed_service.transports.first(), self.admins[0] typing.cast(models.Transport, self.userService.deployed_service.transports.first()), self.admins[0]
) )
def test_user_transport_permissions_staff(self): def test_user_transport_permissions_staff(self) -> None:
self.doTestUserPermissions( self.doTestUserPermissions(
self.userService.deployed_service.transports.first(), self.staffs[0] typing.cast(models.Transport, self.userService.deployed_service.transports.first()), self.staffs[0]
) )
def test_group_transport_permissions_user(self): def test_group_transport_permissions_user(self) -> None:
self.doTestGroupPermissions( self.doTestGroupPermissions(
self.userService.deployed_service.transports.first(), self.users[0] typing.cast(models.Transport, self.userService.deployed_service.transports.first()), self.users[0]
) )
def test_group_transport_permissions_admin(self): def test_group_transport_permissions_admin(self) -> None:
self.doTestGroupPermissions( self.doTestGroupPermissions(
self.userService.deployed_service.transports.first(), self.admins[0] typing.cast(models.Transport, self.userService.deployed_service.transports.first()), self.admins[0]
) )
def test_group_transport_permissions_staff(self): def test_group_transport_permissions_staff(self) -> None:
self.doTestGroupPermissions( self.doTestGroupPermissions(
self.userService.deployed_service.transports.first(), self.staffs[0] typing.cast(models.Transport, self.userService.deployed_service.transports.first()), self.staffs[0]
) )
def test_user_service_permissions_user(self): def test_user_service_permissions_user(self) -> None:
self.doTestUserPermissions( self.doTestUserPermissions(self.userService.deployed_service.service, self.users[0])
self.userService.deployed_service.service, self.users[0]
)
def test_user_service_permissions_admin(self): def test_user_service_permissions_admin(self) -> None:
self.doTestUserPermissions( self.doTestUserPermissions(self.userService.deployed_service.service, self.admins[0])
self.userService.deployed_service.service, self.admins[0]
)
def test_user_service_permissions_staff(self): def test_user_service_permissions_staff(self) -> None:
self.doTestUserPermissions( self.doTestUserPermissions(self.userService.deployed_service.service, self.staffs[0])
self.userService.deployed_service.service, self.staffs[0]
)
def test_group_service_permissions_user(self): def test_group_service_permissions_user(self) -> None:
self.doTestGroupPermissions( self.doTestGroupPermissions(self.userService.deployed_service.service, self.users[0])
self.userService.deployed_service.service, self.users[0]
)
def test_group_service_permissions_admin(self): def test_group_service_permissions_admin(self) -> None:
self.doTestGroupPermissions( self.doTestGroupPermissions(self.userService.deployed_service.service, self.admins[0])
self.userService.deployed_service.service, self.admins[0]
)
def test_group_service_permissions_staff(self): def test_group_service_permissions_staff(self) -> None:
self.doTestGroupPermissions( self.doTestGroupPermissions(self.userService.deployed_service.service, self.staffs[0])
self.userService.deployed_service.service, self.staffs[0]
)
def test_user_provider_permissions_user(self): def test_user_provider_permissions_user(self) -> None:
self.doTestUserPermissions( self.doTestUserPermissions(self.userService.deployed_service.service.provider, self.users[0])
self.userService.deployed_service.service.provider, self.users[0]
)
def test_user_provider_permissions_admin(self): def test_user_provider_permissions_admin(self) -> None:
self.doTestUserPermissions( self.doTestUserPermissions(self.userService.deployed_service.service.provider, self.admins[0])
self.userService.deployed_service.service.provider, self.admins[0]
)
def test_user_provider_permissions_staff(self): def test_user_provider_permissions_staff(self) -> None:
self.doTestUserPermissions( self.doTestUserPermissions(self.userService.deployed_service.service.provider, self.staffs[0])
self.userService.deployed_service.service.provider, self.staffs[0]
)
def test_group_provider_permissions_user(self): def test_group_provider_permissions_user(self) -> None:
self.doTestGroupPermissions( self.doTestGroupPermissions(self.userService.deployed_service.service.provider, self.users[0])
self.userService.deployed_service.service.provider, self.users[0]
)
def test_group_provider_permissions_admin(self): def test_group_provider_permissions_admin(self) -> None:
self.doTestGroupPermissions( self.doTestGroupPermissions(self.userService.deployed_service.service.provider, self.admins[0])
self.userService.deployed_service.service.provider, self.admins[0]
)
def test_group_provider_permissions_staff(self): def test_group_provider_permissions_staff(self) -> None:
self.doTestGroupPermissions( self.doTestGroupPermissions(self.userService.deployed_service.service.provider, self.staffs[0])
self.userService.deployed_service.service.provider, self.staffs[0]
)
def test_user_network_permissions_user(self): def test_user_network_permissions_user(self) -> None:
self.doTestUserPermissions(self.network, self.users[0]) self.doTestUserPermissions(self.network, self.users[0])
def test_user_network_permissions_admin(self): def test_user_network_permissions_admin(self) -> None:
self.doTestUserPermissions(self.network, self.admins[0]) self.doTestUserPermissions(self.network, self.admins[0])
def test_user_network_permissions_staff(self): def test_user_network_permissions_staff(self) -> None:
self.doTestUserPermissions(self.network, self.staffs[0]) self.doTestUserPermissions(self.network, self.staffs[0])
def test_group_network_permissions_user(self): def test_group_network_permissions_user(self) -> None:
self.doTestGroupPermissions(self.network, self.users[0]) self.doTestGroupPermissions(self.network, self.users[0])
def test_group_network_permissions_admin(self): def test_group_network_permissions_admin(self) -> None:
self.doTestGroupPermissions(self.network, self.admins[0]) self.doTestGroupPermissions(self.network, self.admins[0])
def test_group_network_permissions_staff(self): def test_group_network_permissions_staff(self) -> None:
self.doTestGroupPermissions(self.network, self.staffs[0]) self.doTestGroupPermissions(self.network, self.staffs[0])
@staticmethod @staticmethod
def getObjectType(obj: typing.Type) -> int: def getObjectType(obj: typing.Any) -> int:
return objtype.ObjectType.from_model(obj).type return objtype.ObjectType.from_model(obj).type

View File

@ -41,7 +41,7 @@ VALUE_1 = ['unicode', b'string', {'a': 1, 'b': 2.0}]
class StorageTest(UDSTestCase): class StorageTest(UDSTestCase):
def test_storage(self): def test_storage(self) -> None:
strg = storage.Storage(UNICODE_CHARS) strg = storage.Storage(UNICODE_CHARS)
strg.put(UNICODE_CHARS, b'chars') strg.put(UNICODE_CHARS, b'chars')
@ -72,7 +72,7 @@ class StorageTest(UDSTestCase):
self.assertIsNone(strg.get(b'key')) self.assertIsNone(strg.get(b'key'))
self.assertIsNone(strg.get_unpickle('pickle')) self.assertIsNone(strg.get_unpickle('pickle'))
def test_storage_as_dict(self): def test_storage_as_dict(self) -> None:
strg = storage.Storage(UNICODE_CHARS) strg = storage.Storage(UNICODE_CHARS)
strg.put(UNICODE_CHARS, 'chars') strg.put(UNICODE_CHARS, 'chars')
@ -89,7 +89,7 @@ class StorageTest(UDSTestCase):
# because the format is not compatible (with the dict, the values are stored as a tuple, with the original key stored # because the format is not compatible (with the dict, the values are stored as a tuple, with the original key stored
# and with old format, only the value is stored # and with old format, only the value is stored
def test_old_storage_compat(self): def test_old_storage_compat(self) -> None:
models.Storage.objects.create( models.Storage.objects.create(
owner=UNICODE_CHARS, owner=UNICODE_CHARS,
key=storage._old_calculate_key(UNICODE_CHARS.encode(), UNICODE_CHARS.encode()), key=storage._old_calculate_key(UNICODE_CHARS.encode(), UNICODE_CHARS.encode()),
@ -105,7 +105,7 @@ class StorageTest(UDSTestCase):
key=storage._calculate_key(UNICODE_CHARS.encode(), UNICODE_CHARS.encode()), key=storage._calculate_key(UNICODE_CHARS.encode(), UNICODE_CHARS.encode()),
) )
def test_storage_as_dict_old(self): def test_storage_as_dict_old(self) -> None:
models.Storage.objects.create( models.Storage.objects.create(
owner=UNICODE_CHARS, owner=UNICODE_CHARS,
key=storage._old_calculate_key(UNICODE_CHARS.encode(), UNICODE_CHARS.encode()), key=storage._old_calculate_key(UNICODE_CHARS.encode(), UNICODE_CHARS.encode()),

View File

@ -49,81 +49,81 @@ TEST_MAC_RANGE = '00:50:56:10:00:00-00:50:56:3F:FF:FF' # Testing mac range
TEST_MAC_RANGE_FULL = '00:50:56:10:00:00-00:50:56:10:00:10' # Testing mac range for NO MORE MACS TEST_MAC_RANGE_FULL = '00:50:56:10:00:00-00:50:56:10:00:10' # Testing mac range for NO MORE MACS
def macToInt(mac): def mac_to_integer(mac: str) -> int:
return int(mac.replace(':', ''), 16) return int(mac.replace(':', ''), 16)
class UniqueIdTest(UDSTestCase): class UniqueIdTest(UDSTestCase):
uidGen: UniqueIDGenerator uniqueid_generator: UniqueIDGenerator
ugidGen: UniqueGIDGenerator ugidGen: UniqueGIDGenerator
macGen: UniqueMacGenerator macGen: UniqueMacGenerator
nameGen: UniqueNameGenerator name_generator: UniqueNameGenerator
def setUp(self) -> None: def setUp(self) -> None:
self.uidGen = UniqueIDGenerator('uidg1', 'test', 'test') self.uniqueid_generator = UniqueIDGenerator('uidg1', 'test', 'test')
self.ugidGen = UniqueGIDGenerator('test') self.ugidGen = UniqueGIDGenerator('test')
self.macGen = UniqueMacGenerator('test') self.macGen = UniqueMacGenerator('test')
self.nameGen = UniqueNameGenerator('test') self.name_generator = UniqueNameGenerator('test')
def test_seq_uid(self): def test_seq_uid(self) -> None:
for x in range(100): for x in range(100):
self.assertEqual(self.uidGen.get(), x) self.assertEqual(self.uniqueid_generator.get(), x)
self.uidGen.free(40) self.uniqueid_generator.free(40)
self.assertEqual(self.uidGen.get(), 40) self.assertEqual(self.uniqueid_generator.get(), 40)
def test_release_unique_id(self): def test_release_unique_id(self) -> None:
for _ in range(100): for _ in range(100):
self.uidGen.get() self.uniqueid_generator.get()
self.assertEqual(self.uidGen.get(), 100) self.assertEqual(self.uniqueid_generator.get(), 100)
self.uidGen.release() # Clear ups self.uniqueid_generator.release() # Clear ups
self.assertEqual(self.uidGen.get(), 0) self.assertEqual(self.uniqueid_generator.get(), 0)
def test_release_older_unique_id(self): def test_release_older_unique_id(self) -> None:
NUM = 100 NUM = 100
for i in range(NUM): for i in range(NUM):
self.assertEqual(self.uidGen.get(), i) self.assertEqual(self.uniqueid_generator.get(), i)
stamp = sql_stamp_seconds() + 1 stamp = sql_stamp_seconds() + 1
time.sleep(2) time.sleep(2)
for i in range(NUM): for i in range(NUM):
self.assertEqual(self.uidGen.get(), i + NUM) self.assertEqual(self.uniqueid_generator.get(), i + NUM)
self.uidGen.release_older_than(stamp) # Clear ups older than 0 seconds ago self.uniqueid_generator.release_older_than(stamp) # Clear ups older than 0 seconds ago
for i in range(NUM): for i in range(NUM):
self.assertEqual(self.uidGen.get(), i) self.assertEqual(self.uniqueid_generator.get(), i)
# from NUM to NUM*2-1 (both included) are still there, so we should get 200 # from NUM to NUM*2-1 (both included) are still there, so we should get 200
self.assertEqual(self.uidGen.get(), NUM * 2) self.assertEqual(self.uniqueid_generator.get(), NUM * 2)
self.assertEqual(self.uidGen.get(), NUM * 2 + 1) self.assertEqual(self.uniqueid_generator.get(), NUM * 2 + 1)
def test_gid(self): def test_gid(self) -> None:
for x in range(100): for x in range(100):
self.assertEqual(self.ugidGen.get(), f'uds{x:08d}') self.assertEqual(self.ugidGen.get(), f'uds{x:08d}')
def test_gid_basename(self): def test_gid_basename(self) -> None:
self.ugidGen.set_basename('mar') self.ugidGen.set_basename('mar')
for x in range(100): for x in range(100):
self.assertEqual(self.ugidGen.get(), f'mar{x:08d}') self.assertEqual(self.ugidGen.get(), f'mar{x:08d}')
def test_mac(self): def test_mac(self) -> None:
start, end = TEST_MAC_RANGE.split('-') # pylint: disable=unused-variable start, _end = TEST_MAC_RANGE.split('-') # pylint: disable=unused-variable
self.assertEqual(self.macGen.get(TEST_MAC_RANGE), start) self.assertEqual(self.macGen.get(TEST_MAC_RANGE), start)
starti = macToInt(start) + 1 # We have already got 1 element starti = mac_to_integer(start) + 1 # We have already got 1 element
lst = [start] lst = [start]
for x in range(400): for x in range(400):
mac = self.macGen.get(TEST_MAC_RANGE) mac = self.macGen.get(TEST_MAC_RANGE)
self.assertEqual(macToInt(mac), starti + x) self.assertEqual(mac_to_integer(mac), starti + x)
lst.append(mac) lst.append(mac)
for x in lst: for x in lst:
@ -131,37 +131,37 @@ class UniqueIdTest(UDSTestCase):
self.assertEqual(self.macGen.get(TEST_MAC_RANGE), start) self.assertEqual(self.macGen.get(TEST_MAC_RANGE), start)
def test_mac_full(self): def test_mac_full(self) -> None:
start, end = TEST_MAC_RANGE_FULL.split('-') start, end = TEST_MAC_RANGE_FULL.split('-')
length = macToInt(end) - macToInt(start) + 1 length = mac_to_integer(end) - mac_to_integer(start) + 1
starti = macToInt(start) starti = mac_to_integer(start)
for x in range(length): for x in range(length):
self.assertEqual(macToInt(self.macGen.get(TEST_MAC_RANGE_FULL)), starti + x) self.assertEqual(mac_to_integer(self.macGen.get(TEST_MAC_RANGE_FULL)), starti + x)
for x in range(20): for x in range(20):
self.assertEqual(self.macGen.get(TEST_MAC_RANGE_FULL), '00:00:00:00:00:00') self.assertEqual(self.macGen.get(TEST_MAC_RANGE_FULL), '00:00:00:00:00:00')
def test_name(self): def test_name(self) -> None:
lst = [] lst: list[str] = []
num = 0 num = 0
for length in range(2, 10): for length in range(2, 10):
for x in range(20): for x in range(20):
name = self.nameGen.get('test', length=length) name = self.name_generator.get('test', length=length)
lst.append(name) lst.append(name)
self.assertEqual(name, f'test{num:0{length}d}'.format(num, width=length)) self.assertEqual(name, f'test{num:0{length}d}'.format(num, width=length))
num += 1 num += 1
for x in lst: for x in lst:
self.nameGen.free('test', x) self.name_generator.free('test', x)
self.assertEqual(self.nameGen.get('test', length=1), 'test0') self.assertEqual(self.name_generator.get('test', length=1), 'test0')
def test_name_full(self): def test_name_full(self) -> None:
for _ in range(10): for _ in range(10):
self.nameGen.get('test', length=1) self.name_generator.get('test', length=1)
with self.assertRaises(KeyError): with self.assertRaises(KeyError):
self.nameGen.get('test', length=1) self.name_generator.get('test', length=1)

View File

@ -47,13 +47,13 @@ from ...fixtures import services as fixtures_services
class AssignedAndUnusedTest(UDSTestCase): class AssignedAndUnusedTest(UDSTestCase):
userServices: list[models.UserService] userServices: list[models.UserService]
def setUp(self): def setUp(self) -> None:
config.GlobalConfig.CHECK_UNUSED_TIME.set('600') config.GlobalConfig.CHECK_UNUSED_TIME.set('600')
AssignedAndUnused.setup() AssignedAndUnused.setup()
# All created user services has "in_use" to False, os_state and state to USABLE # All created user services has "in_use" to False, os_state and state to USABLE
self.userServices = fixtures_services.create_cache_testing_userservices(count=32) self.userServices = fixtures_services.create_cache_testing_userservices(count=32)
def test_assigned_unused(self): def test_assigned_unused(self) -> None:
for us in self.userServices: # Update state date to now for us in self.userServices: # Update state date to now
us.set_state(State.USABLE) us.set_state(State.USABLE)
# Set now, should not be removed # Set now, should not be removed

View File

@ -50,7 +50,7 @@ TEST_SERVICES = 5 * 5 # Ensure multiple of 5 for testing
class HangedCleanerTest(UDSTestCase): class HangedCleanerTest(UDSTestCase):
userServices: list[models.UserService] userServices: list[models.UserService]
def setUp(self): def setUp(self) -> None:
config.GlobalConfig.MAX_INITIALIZING_TIME.set(MAX_INIT) config.GlobalConfig.MAX_INITIALIZING_TIME.set(MAX_INIT)
config.GlobalConfig.MAX_REMOVAL_TIME.set(MAX_INIT) config.GlobalConfig.MAX_REMOVAL_TIME.set(MAX_INIT)
HangedCleaner.setup() HangedCleaner.setup()
@ -82,7 +82,7 @@ class HangedCleanerTest(UDSTestCase):
) )
us.save(update_fields=['state', 'os_state', 'state_date']) us.save(update_fields=['state', 'os_state', 'state_date'])
def test_hanged_cleaner(self): def test_hanged_cleaner(self) -> None:
# At start, there is no "removable" user services # At start, there is no "removable" user services
cleaner = HangedCleaner(Environment.testing_environment()) cleaner = HangedCleaner(Environment.testing_environment())
cleaner.run() cleaner.run()

View File

@ -78,7 +78,7 @@ class StuckCleanerTest(UDSTestCase):
us.save(update_fields=['state_date', 'state', 'os_state']) us.save(update_fields=['state_date', 'state', 'os_state'])
def test_worker_outdated(self): def test_worker_outdated(self) -> None:
count = UserService.objects.count() count = UserService.objects.count()
cleaner = StuckCleaner(Environment.testing_environment()) cleaner = StuckCleaner(Environment.testing_environment())
cleaner.run() cleaner.run()
@ -86,7 +86,7 @@ class StuckCleanerTest(UDSTestCase):
UserService.objects.count(), count // 4 UserService.objects.count(), count // 4
) # 3/4 of user services should be removed ) # 3/4 of user services should be removed
def test_worker_not_outdated(self): def test_worker_not_outdated(self) -> None:
# Fix state_date to be less than 1 day for all user services # Fix state_date to be less than 1 day for all user services
for us in self.userServices: for us in self.userServices:
us.state_date = datetime.datetime.now() - datetime.timedelta( us.state_date = datetime.datetime.now() - datetime.timedelta(

View File

@ -44,7 +44,7 @@ class RedirectMiddlewareTest(test.UDSTransactionTestCase):
""" """
Test client functionality Test client functionality
""" """
def test_redirect(self): def test_redirect(self) -> None:
RedirectMiddlewareTest.add_middleware('uds.middleware.redirect.RedirectMiddleware') RedirectMiddlewareTest.add_middleware('uds.middleware.redirect.RedirectMiddleware')
page = 'https://testserver' + reverse('page.index') page = 'https://testserver' + reverse('page.index')
response = self.client.get('/', secure=False) response = self.client.get('/', secure=False)

View File

@ -24,7 +24,7 @@ from uds.services.PhysicalMachines import provider, deployment
class IPMachineUserServiceSerializationTest(UDSTestCase): class IPMachineUserServiceSerializationTest(UDSTestCase):
def test_marshalling(self): def test_marshalling(self) -> None:
obj = deployment.OldIPSerialData() obj = deployment.OldIPSerialData()
obj._ip = '1.1.1.1' obj._ip = '1.1.1.1'
obj._state = 'state' obj._state = 'state'
@ -37,7 +37,7 @@ class IPMachineUserServiceSerializationTest(UDSTestCase):
data = obj.marshal() data = obj.marshal()
instance = deployment.IPMachineUserService(environment=Environment.testing_environment(), service=None) instance = deployment.IPMachineUserService(environment=Environment.testing_environment(), service=None) # type: ignore # service is not used
instance.unmarshal(data) instance.unmarshal(data)
marshaled_data = instance.marshal() marshaled_data = instance.marshal()
@ -53,7 +53,7 @@ class IPMachineUserServiceSerializationTest(UDSTestCase):
_check_fields(instance) _check_fields(instance)
# Reunmarshall again and check that remarshalled flag is not set # Reunmarshall again and check that remarshalled flag is not set
instance = deployment.IPMachineUserService(environment=Environment.testing_environment(), service=None) instance = deployment.IPMachineUserService(environment=Environment.testing_environment(), service=None) # type: ignore # service is not used
instance.unmarshal(marshaled_data) instance.unmarshal(marshaled_data)
self.assertFalse(instance.needs_upgrade()) self.assertFalse(instance.needs_upgrade())

View File

@ -96,7 +96,7 @@ class XenDeploymentSerializationTest(UDSTransactionTestCase):
environment = Environment.testing_environment() environment = Environment.testing_environment()
def _create_instance(unmarshal_data: 'bytes|None' = None) -> Deployment: def _create_instance(unmarshal_data: 'bytes|None' = None) -> Deployment:
instance = Deployment(environment=environment, service=None) instance = Deployment(environment=environment, service=None) # type: ignore # service is not used
if unmarshal_data: if unmarshal_data:
instance.unmarshal(unmarshal_data) instance.unmarshal(unmarshal_data)
return instance return instance
@ -127,7 +127,7 @@ class XenDeploymentSerializationTest(UDSTransactionTestCase):
environment.storage.put_pickle('queue', TEST_QUEUE) environment.storage.put_pickle('queue', TEST_QUEUE)
def _create_instance(unmarshal_data: 'bytes|None' = None) -> Deployment: def _create_instance(unmarshal_data: 'bytes|None' = None) -> Deployment:
instance = Deployment(environment=environment, service=None) instance = Deployment(environment=environment, service=None) # type: ignore # service is not used
if unmarshal_data: if unmarshal_data:
instance.unmarshal(unmarshal_data) instance.unmarshal(unmarshal_data)
return instance return instance
@ -171,6 +171,6 @@ class XenDeploymentSerializationTest(UDSTransactionTestCase):
# This test is designed to ensure that all fields are autoserializable # This test is designed to ensure that all fields are autoserializable
# If some field is added or removed, this tests will warn us about it to fix the rest of the related tests # If some field is added or removed, this tests will warn us about it to fix the rest of the related tests
with Environment.temporary_environment() as env: with Environment.temporary_environment() as env:
instance = Deployment(environment=env, service=None) instance = Deployment(environment=env, service=None) # type: ignore # service is not used
self.assertSetEqual(set(f[0] for f in instance._autoserializable_fields()), EXPECTED_FIELDS) self.assertSetEqual(set(f[0] for f in instance._autoserializable_fields()), EXPECTED_FIELDS)

View File

@ -136,11 +136,11 @@ def random_value(
class RestStruct: class RestStruct:
def __init__(self, **kwargs) -> None: def __init__(self, **kwargs: typing.Any) -> None:
for k, v in kwargs.items(): for k, v in kwargs.items():
setattr(self, k, v) setattr(self, k, v)
def as_dict(self, **kwargs) -> dict[str, typing.Any]: def as_dict(self, **kwargs: typing.Any) -> dict[str, typing.Any]:
# Use kwargs to override values # Use kwargs to override values
res = {k: kwargs.get(k, getattr(self, k)) for k in self.__annotations__} # pylint: disable=no-member res = {k: kwargs.get(k, getattr(self, k)) for k in self.__annotations__} # pylint: disable=no-member
# Remove None values for optional fields # Remove None values for optional fields
@ -158,7 +158,7 @@ class RestStruct:
} }
@classmethod @classmethod
def random_create(cls, **kwargs) -> 'RestStruct': def random_create(cls, **kwargs: typing.Any) -> 'RestStruct':
# Use kwargs to override values # Use kwargs to override values
# Extract type from annotations # Extract type from annotations
return cls(**{k: random_value(v, kwargs.get(k, None)) for k, v in cls.__annotations__.items()}) return cls(**{k: random_value(v, kwargs.get(k, None)) for k, v in cls.__annotations__.items()})

View File

@ -44,8 +44,8 @@ logger = logging.getLogger(__name__)
def assertUserIs( def assertUserIs(
user: models.User, user: models.User,
compare_to: collections.abc.Mapping[str, typing.Any], compare_to: collections.abc.Mapping[str, typing.Any],
compare_uuid=False, compare_uuid: bool=False,
compare_password=False, compare_password: bool=False,
) -> bool: ) -> bool:
ignore_fields = ['password', 'groups', 'mfa_data', 'last_access', 'role'] ignore_fields = ['password', 'groups', 'mfa_data', 'last_access', 'role']
@ -100,7 +100,7 @@ def assertUserIs(
def assertGroupIs( def assertGroupIs(
group: models.Group, compare_to: collections.abc.Mapping[str, typing.Any], compare_uuid=False group: models.Group, compare_to: collections.abc.Mapping[str, typing.Any], compare_uuid: bool=False
) -> bool: ) -> bool:
ignore_fields = ['groups', 'users', 'is_meta', 'type', 'pools'] ignore_fields = ['groups', 'users', 'is_meta', 'type', 'pools']
@ -143,7 +143,7 @@ def assertGroupIs(
def assertServicePoolIs( def assertServicePoolIs(
pool: models.ServicePool, pool: models.ServicePool,
compare_to: collections.abc.Mapping[str, typing.Any], compare_to: collections.abc.Mapping[str, typing.Any],
compare_uuid=False, compare_uuid: bool=False,
) -> bool: ) -> bool:
ignore_fields = [ ignore_fields = [
'tags', 'tags',

View File

@ -26,8 +26,9 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
""" """
@author: Adolfo Gómez, dkmaster at dkmon dot com Author: Adolfo Gómez, dkmaster at dkmon dot com
""" """
# pyright: reportUnknownMemberType=false
import typing import typing
import collections.abc import collections.abc
import logging import logging
@ -50,11 +51,12 @@ class UDSHttpResponse(HttpResponse):
""" """
Custom response class to be able to access the response content Custom response class to be able to access the response content
""" """
url: str url: str
def __init__(self, content: bytes, *args: typing.Any, **kwargs: typing.Any) -> None: def __init__(self, content: bytes, *args: typing.Any, **kwargs: typing.Any) -> None:
super().__init__(content, *args, **kwargs) super().__init__(content, *args, **kwargs)
def json(self) -> typing.Any: def json(self) -> typing.Any:
return super().json() return super().json()
@ -100,7 +102,7 @@ class UDSClientMixin:
kwargs['REMOTE_ADDR'] = '127.0.0.1' kwargs['REMOTE_ADDR'] = '127.0.0.1'
elif self.ip_version == 6: elif self.ip_version == 6:
kwargs['REMOTE_ADDR'] = '::1' kwargs['REMOTE_ADDR'] = '::1'
kwargs['headers'] = self.uds_headers kwargs['headers'] = self.uds_headers
def compose_rest_url(self, method: str) -> str: def compose_rest_url(self, method: str) -> str:
@ -164,7 +166,7 @@ class UDSClient(UDSClientMixin, Client):
return self.delete(self.compose_rest_url(method), *args, **kwargs) return self.delete(self.compose_rest_url(method), *args, **kwargs)
class UDSAsyncClient(UDSClientMixin, AsyncClient): class UDSAsyncClient(UDSClientMixin, AsyncClient): # type: ignore # Django stubs do not include AsyncClient
def __init__( def __init__(
self, self,
enforce_csrf_checks: bool = False, enforce_csrf_checks: bool = False,
@ -173,7 +175,7 @@ class UDSAsyncClient(UDSClientMixin, AsyncClient):
): ):
# Instantiate the client and add basic user agent # Instantiate the client and add basic user agent
AsyncClient.__init__(self, enforce_csrf_checks, raise_request_exception) AsyncClient.__init__(self, enforce_csrf_checks, raise_request_exception) # pyright: ignore
UDSClientMixin.initialize(self) UDSClientMixin.initialize(self)
# and required UDS cookie # and required UDS cookie
@ -184,7 +186,7 @@ class UDSAsyncClient(UDSClientMixin, AsyncClient):
request = request.copy() request = request.copy()
# Add headers # Add headers
request.update(self.uds_headers) request.update(self.uds_headers)
return await super().request(**request) return await super().request(**request) # pyright: ignore
# pylint: disable=invalid-overridden-method # pylint: disable=invalid-overridden-method
async def get(self, *args: typing.Any, **kwargs: typing.Any) -> 'UDSHttpResponse': async def get(self, *args: typing.Any, **kwargs: typing.Any) -> 'UDSHttpResponse':
@ -244,7 +246,7 @@ class UDSTestCaseMixin:
pass # Not present pass # Not present
class UDSTestCase(UDSTestCaseMixin, TestCase): class UDSTestCase(UDSTestCaseMixin, TestCase): # pyright: ignore # Overrides superclass client
@classmethod @classmethod
def setUpClass(cls) -> None: def setUpClass(cls) -> None:
super().setUpClass() super().setUpClass()
@ -253,7 +255,8 @@ class UDSTestCase(UDSTestCaseMixin, TestCase):
def create_environment(self) -> Environment: def create_environment(self) -> Environment:
return Environment.testing_environment() return Environment.testing_environment()
class UDSTransactionTestCase(UDSTestCaseMixin, TransactionTestCase):
class UDSTransactionTestCase(UDSTestCaseMixin, TransactionTestCase): # pyright: ignore # superclass client
@classmethod @classmethod
def setUpClass(cls) -> None: def setUpClass(cls) -> None:
super().setUpClass() super().setUpClass()

View File

@ -77,7 +77,7 @@ def logout(caller: SimpleTestCase, client: Client, auth_token: str) -> None:
response = client.get( response = client.get(
'/uds/rest/auth/logout', '/uds/rest/auth/logout',
content_type='application/json', content_type='application/json',
**{consts.auth.AUTH_TOKEN_HEADER: auth_token} # type: ignore **{consts.auth.AUTH_TOKEN_HEADER: auth_token} # pyright: ignore
) )
caller.assertEqual(response.status_code, 200, 'Logout') caller.assertEqual(response.status_code, 200, 'Logout')
caller.assertEqual(response.json(), {'result': 'ok'}, 'Logout') caller.assertEqual(response.json(), {'result': 'ok'}, 'Logout')

View File

@ -56,7 +56,7 @@ class WebLoginLogoutTest(test.WEBTestCase):
response = typing.cast('HttpResponse', self.client.get('/uds/utility/uds.js')) response = typing.cast('HttpResponse', self.client.get('/uds/utility/uds.js'))
self.assertContains(response, '"errors": ["Access denied"]', status_code=200) self.assertContains(response, '"errors": ["Access denied"]', status_code=200)
def test_login_logout_success(self): def test_login_logout_success(self) -> None:
""" """
Test login and logout Test login and logout
""" """
@ -117,7 +117,7 @@ class WebLoginLogoutTest(test.WEBTestCase):
response = self.do_login('invalid', rootpass, auth.uuid) response = self.do_login('invalid', rootpass, auth.uuid)
self.assertInvalidLogin(response) self.assertInvalidLogin(response)
def test_login_valid_user_no_group(self): def test_login_valid_user_no_group(self) -> None:
user = fixtures_authenticators.create_users( user = fixtures_authenticators.create_users(
fixtures_authenticators.create_authenticator(), fixtures_authenticators.create_authenticator(),
)[0] )[0]
@ -147,7 +147,7 @@ class WebLoginLogoutTest(test.WEBTestCase):
self.assertEqual(models.Log.objects.count(), 12) self.assertEqual(models.Log.objects.count(), 12)
def test_login_invalid_user(self): def test_login_invalid_user(self) -> None:
user = fixtures_authenticators.create_users( user = fixtures_authenticators.create_users(
fixtures_authenticators.create_authenticator(), fixtures_authenticators.create_authenticator(),
)[0] )[0]