1
0
mirror of https://github.com/dkmstr/openuds.git synced 2025-01-05 09:17:54 +03:00

Refactor weblogin callback to use secure request session and constant timeout

Extracted client plugins downloadables to its own file for readibility and code maintenance
This commit is contained in:
Adolfo Gómez García 2024-09-28 16:45:14 +02:00
parent f53c9bb793
commit f7df9b2ae8
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
7 changed files with 83 additions and 69 deletions

View File

@ -61,27 +61,27 @@ class AuthCallbackTest(UDSTestCase):
def test_no_callback(self) -> None: def test_no_callback(self) -> None:
config.GlobalConfig.NOTIFY_CALLBACK_URL.set('') # Clean callback url config.GlobalConfig.NOTIFY_CALLBACK_URL.set('') # Clean callback url
with mock.patch('requests.post') as mock_post: with mock.patch('uds.core.util.security.secure_requests_session') as session_mock:
callbacks.weblogin(self.user) callbacks.weblogin(self.user)
mock_post.assert_not_called() session_mock.assert_not_called()
def test_callback_failed_url(self) -> None: def test_callback_failed_url(self) -> None:
config.GlobalConfig.NOTIFY_CALLBACK_URL.set('http://localhost:1234') # Sample non existent url config.GlobalConfig.NOTIFY_CALLBACK_URL.set('http://localhost:1234') # Sample non existent url
callbacks.FAILURE_CACHE.set('notify_failure', 3) # Already failed 3 times callbacks.FAILURE_CACHE.set('notify_failure', 3) # Already failed 3 times
with mock.patch('requests.post') as mock_post: with mock.patch('uds.core.util.security.secure_requests_session') as session_mock:
callbacks.weblogin(self.user) callbacks.weblogin(self.user)
mock_post.assert_not_called() session_mock.assert_not_called()
def test_callback_fails_repeteadly(self) -> None: def test_callback_fails_repeteadly(self) -> None:
config.GlobalConfig.NOTIFY_CALLBACK_URL.set('https://localhost:1234') config.GlobalConfig.NOTIFY_CALLBACK_URL.set('https://localhost:1234')
with mock.patch('requests.post') as mock_post: with mock.patch('uds.core.util.security.secure_requests_session') as session_mock:
mock_post.side_effect = Exception('Error') session_mock.return_value.post.side_effect = Exception('Error')
for _i in range(16): for _i in range(16):
callbacks.weblogin(self.user) callbacks.weblogin(self.user)
self.assertEqual(mock_post.call_count, 3) self.assertEqual(session_mock.call_count, 3)
def test_callback_change_groups(self) -> None: def test_callback_change_groups(self) -> None:
config.GlobalConfig.NOTIFY_CALLBACK_URL.set('https://localhost:1234') config.GlobalConfig.NOTIFY_CALLBACK_URL.set('https://localhost:1234')
@ -91,13 +91,13 @@ class AuthCallbackTest(UDSTestCase):
diff_groups = all_groups - current_groups diff_groups = all_groups - current_groups
with mock.patch('requests.post') as mock_post: with mock.patch('uds.core.util.security.secure_requests_session') as session_mock:
mock_post.return_value.json.return_value = { session_mock.return_value.post.return_value.json.return_value = {
'new_groups': list(diff_groups), 'new_groups': list(diff_groups),
'removed_groups': list(current_groups), 'removed_groups': list(current_groups),
} }
callbacks.weblogin(self.user) callbacks.weblogin(self.user)
self.assertEqual(mock_post.call_count, 1) self.assertEqual(session_mock.call_count, 1)
self.assertEqual({group.name for group in self.user.groups.all()}, diff_groups) self.assertEqual({group.name for group in self.user.groups.all()}, diff_groups)

View File

@ -69,7 +69,7 @@ def weblogin(user: models.User) -> None:
'groups': [group.name for group in user.groups.all()], 'groups': [group.name for group in user.groups.all()],
}, },
}, },
timeout=consts.net.URGENT_REQUEST_TIMEOUT, timeout=consts.net.SHORT_REQUEST_TIMEOUT,
) )
response.raise_for_status() response.raise_for_status()
FAILURE_CACHE.delete('notify_failure') FAILURE_CACHE.delete('notify_failure')

View File

@ -35,7 +35,7 @@ import typing
# Request related timeouts, etc.. # Request related timeouts, etc..
DEFAULT_REQUEST_TIMEOUT: typing.Final[int] = 20 # In seconds DEFAULT_REQUEST_TIMEOUT: typing.Final[int] = 20 # In seconds
DEFAULT_CONNECT_TIMEOUT: typing.Final[int] = 4 # In seconds DEFAULT_CONNECT_TIMEOUT: typing.Final[int] = 4 # In seconds
URGENT_REQUEST_TIMEOUT: typing.Final[int] = 3 # In seconds SHORT_REQUEST_TIMEOUT: typing.Final[int] = 3 # In seconds
# Default UDS Registerd Server listen port # Default UDS Registerd Server listen port
SERVER_DEFAULT_LISTEN_PORT: typing.Final[int] = 43910 SERVER_DEFAULT_LISTEN_PORT: typing.Final[int] = 43910

View File

@ -43,7 +43,7 @@ class ExtendedHttpRequest(HttpRequest):
ip_version: int ip_version: int
ip_proxy: str ip_proxy: str
os: 'types.os.DetectedOsInfo' os: 'types.os.DetectedOsInfo'
user: typing.Optional['User'] # type: ignore user: typing.Optional['User'] # type: ignore # Overrides the user attribute from HttpRequest
authorized: bool authorized: bool

View File

@ -128,6 +128,12 @@ class Authenticator(ManagedObjectModel, TaggingMixin):
# If type is not registered (should be, but maybe a database inconsistence), consider this a "base empty auth" # If type is not registered (should be, but maybe a database inconsistence), consider this a "base empty auth"
return auths.factory().lookup(self.data_type) or auths.Authenticator return auths.factory().lookup(self.data_type) or auths.Authenticator
def type_is_valid(self) -> bool:
"""
Returns if the type of this authenticator exists
"""
return auths.factory().lookup(self.data_type) is not None
def get_or_create_user(self, username: str, realName: typing.Optional[str] = None) -> 'User': def get_or_create_user(self, username: str, realName: typing.Optional[str] = None) -> 'User':
""" """
Used to get or create a new user at database associated with this authenticator. Used to get or create a new user at database associated with this authenticator.

View File

@ -39,12 +39,13 @@ from django.utils.translation import gettext, get_language
from django.urls import reverse from django.urls import reverse
from django.templatetags.static import static from django.templatetags.static import static
from uds.REST.methods.client import CLIENT_VERSION
from uds.core import consts from uds.core import consts
from uds.core.managers import downloads_manager from uds.core.managers import downloads_manager
from uds.core.util.config import GlobalConfig from uds.core.util.config import GlobalConfig
from uds.models import Authenticator, Image, Network, Transport from uds.models import Authenticator, Image, Network, Transport
from . import udsclients_info
# Not imported at runtime, just for type checking # Not imported at runtime, just for type checking
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from uds.core.types.requests import ExtendedHttpRequest from uds.core.types.requests import ExtendedHttpRequest
@ -92,13 +93,13 @@ def uds_js(request: 'ExtendedHttpRequest') -> str:
# the auths for client # the auths for client
def _get_auth_info(auth: Authenticator) -> dict[str, typing.Any]: def _get_auth_info(auth: Authenticator) -> dict[str, typing.Any]:
theType = auth.get_type() auth_type = auth.get_type()
return { return {
'id': auth.uuid, 'id': auth.uuid,
'name': auth.name, 'name': auth.name,
'label': auth.small_name, 'label': auth.small_name,
'priority': auth.priority, 'priority': auth.priority,
'is_custom': theType.is_custom(), 'is_custom': auth_type.is_custom(),
} }
config: dict[str, typing.Any] = { config: dict[str, typing.Any] = {
@ -106,7 +107,7 @@ def uds_js(request: 'ExtendedHttpRequest') -> str:
'version_stamp': consts.system.VERSION_STAMP, 'version_stamp': consts.system.VERSION_STAMP,
'language': get_language(), 'language': get_language(),
'available_languages': [{'id': k, 'name': gettext(v)} for k, v in settings.LANGUAGES], 'available_languages': [{'id': k, 'name': gettext(v)} for k, v in settings.LANGUAGES],
'authenticators': [_get_auth_info(auth) for auth in authenticators if auth.get_type()], 'authenticators': [_get_auth_info(auth) for auth in authenticators if auth.type_is_valid()],
'mfa': request.session.get('mfa', None), 'mfa': request.session.get('mfa', None),
'tag': tag, 'tag': tag,
'os': request.os.os.name, 'os': request.os.os.name,
@ -122,7 +123,7 @@ def uds_js(request: 'ExtendedHttpRequest') -> str:
'launcher_wait_time': 5000, 'launcher_wait_time': 5000,
'messages': { 'messages': {
# Calendar denied message # Calendar denied message
'calendarDenied': GlobalConfig.LIMITED_BY_CALENDAR_TEXT.get().strip() 'calendar_denied': GlobalConfig.LIMITED_BY_CALENDAR_TEXT.get().strip()
or gettext('Access limited by calendar') or gettext('Access limited by calendar')
}, },
'urls': { 'urls': {
@ -167,57 +168,7 @@ def uds_js(request: 'ExtendedHttpRequest') -> str:
'ip_proxy': request.ip_proxy, 'ip_proxy': request.ip_proxy,
} }
# all plugins are under url clients... plugins = udsclients_info.PLUGINS.copy()
plugins = [
{
'url': static('clients/' + url.format(version=CLIENT_VERSION)),
'description': description,
'name': name,
'legacy': legacy,
}
for url, description, name, legacy in (
(
'UDSClientSetup-{version}.exe',
gettext('Windows client'),
'Windows',
False,
),
('UDSClient-{version}.pkg', gettext('Mac OS X client'), 'MacOS', False),
(
'udsclient3_{version}_all.deb',
gettext('Debian based Linux client') + ' ' + gettext('(requires Python-3.9 or newer)'),
'Linux',
False,
),
(
'udsclient3-{version}-1.noarch.rpm',
gettext('RPM based Linux client (Fedora, Suse, ...)')
+ ' '
+ gettext('(requires Python-3.9 or newer)'),
'Linux',
False,
),
(
'udsclient3-x86_64-{version}.tar.gz',
gettext('Binary appimage X86_64 Linux client'),
'Linux',
False,
),
(
'udsclient3-armhf-{version}.tar.gz',
gettext('Binary appimage ARMHF Linux client (Raspberry, ...)'),
'Linux',
False,
),
(
'udsclient3-{version}.tar.gz',
gettext('Generic .tar.gz Linux client') + ' ' + gettext('(requires Python-3.9 or newer)'),
'Linux',
False,
),
)
]
# We can add here custom downloads with something like this: # We can add here custom downloads with something like this:
# plugins.append({ # plugins.append({
# 'url': 'http://www.google.com/coche.exe', # 'url': 'http://www.google.com/coche.exe',

View File

@ -0,0 +1,57 @@
import typing
from django.utils.translation import gettext
from django.templatetags.static import static
from uds.REST.methods.client import CLIENT_VERSION
# all plugins are under url clients...
PLUGINS: typing.Final[list[dict[str, 'str|bool']]] = [
{
'url': static('clients/' + url.format(version=CLIENT_VERSION)),
'description': description,
'name': name,
'legacy': legacy,
}
for url, description, name, legacy in (
(
'UDSClientSetup-{version}.exe',
gettext('Windows client'),
'Windows',
False,
),
('UDSClient-{version}.pkg', gettext('Mac OS X client'), 'MacOS', False),
(
'udsclient3_{version}_all.deb',
gettext('Debian based Linux client') + ' ' + gettext('(requires Python-3.9 or newer)'),
'Linux',
False,
),
(
'udsclient3-{version}-1.noarch.rpm',
gettext('RPM based Linux client (Fedora, Suse, ...)')
+ ' '
+ gettext('(requires Python-3.9 or newer)'),
'Linux',
False,
),
(
'udsclient3-x86_64-{version}.tar.gz',
gettext('Binary appimage X86_64 Linux client'),
'Linux',
False,
),
(
'udsclient3-armhf-{version}.tar.gz',
gettext('Binary appimage ARMHF Linux client (Raspberry, ...)'),
'Linux',
False,
),
(
'udsclient3-{version}.tar.gz',
gettext('Generic .tar.gz Linux client') + ' ' + gettext('(requires Python-3.9 or newer)'),
'Linux',
False,
),
)
]