diff --git a/server/src/uds/auths/InternalDB/authenticator.py b/server/src/uds/auths/InternalDB/authenticator.py index 1c743437b..c72196647 100644 --- a/server/src/uds/auths/InternalDB/authenticator.py +++ b/server/src/uds/auths/InternalDB/authenticator.py @@ -41,7 +41,7 @@ import dns.reversename from django.utils.translation import gettext_noop as _ from uds.core import auths, types, exceptions, consts -from uds.core.auths.auth import authenticate_log_login +from uds.core.auths.auth import log_login from uds.core.managers.crypto import CryptoManager from uds.core.ui import gui from uds.core.util.state import State @@ -154,7 +154,7 @@ class InternalDBAuth(auths.Authenticator): try: user: 'models.User' = dbAuth.users.get(name=username, state=State.ACTIVE) except Exception: - authenticate_log_login(request, self.db_obj(), username, 'Invalid user') + log_login(request, self.db_obj(), username, 'Invalid user') return types.auth.FAILED_AUTH if user.parent: # Direct auth not allowed for "derived" users @@ -165,7 +165,7 @@ class InternalDBAuth(auths.Authenticator): groupsManager.validate([g.name for g in user.groups.all()]) return types.auth.SUCCESS_AUTH - authenticate_log_login(request, self.db_obj(), username, 'Invalid password') + log_login(request, self.db_obj(), username, 'Invalid password') return types.auth.FAILED_AUTH def get_groups(self, username: str, groupsManager: 'auths.GroupsManager'): diff --git a/server/src/uds/auths/OAuth2/authenticator.py b/server/src/uds/auths/OAuth2/authenticator.py index 9fb2561ba..613024e7a 100644 --- a/server/src/uds/auths/OAuth2/authenticator.py +++ b/server/src/uds/auths/OAuth2/authenticator.py @@ -226,7 +226,7 @@ class OAuth2Authenticator(auths.Authenticator): tab=_('Attributes'), ) - def _getPublicKeys(self) -> list[typing.Any]: # In fact, any of the PublicKey types + def _get_public_keys(self) -> list[typing.Any]: # In fact, any of the PublicKey types # Get certificates in self.publicKey.value, encoded as PEM # Return a list of certificates in DER format if self.publicKey.value.strip() == '': @@ -234,7 +234,7 @@ class OAuth2Authenticator(auths.Authenticator): return [cert.public_key() for cert in fields.get_vertificates_from_field(self.publicKey)] - def _codeVerifierAndChallenge(self) -> tuple[str, str]: + def _code_verifier_and_challenge(self) -> tuple[str, str]: """Generate a code verifier and a code challenge for PKCE Returns: @@ -249,7 +249,7 @@ class OAuth2Authenticator(auths.Authenticator): return codeVerifier, codeChallenge - def _getResponseTypeString(self) -> str: + def _get_response_type_string(self) -> str: match self.responseType.value: case 'code': return 'code' @@ -264,14 +264,14 @@ class OAuth2Authenticator(auths.Authenticator): case _: raise Exception('Invalid response type') - def _getLoginURL(self, request: 'HttpRequest') -> str: + def _get_login_url(self, request: 'HttpRequest') -> str: """ :type request: django.http.request.HttpRequest """ state: str = secrets.token_urlsafe(STATE_LENGTH) param_dict = { - 'response_type': self._getResponseTypeString(), + 'response_type': self._get_response_type_string(), 'client_id': self.clientId.value, 'redirect_uri': self.redirectionEndpoint.value, 'scope': self.scope.value.replace(',', ' '), @@ -296,7 +296,7 @@ class OAuth2Authenticator(auths.Authenticator): case 'pkce': # PKCE flow - codeVerifier, codeChallenge = self._codeVerifierAndChallenge() + codeVerifier, codeChallenge = self._code_verifier_and_challenge() param_dict['code_challenge'] = codeChallenge param_dict['code_challenge_method'] = 'S256' self.cache.put(state, codeVerifier, 3600) @@ -315,7 +315,7 @@ class OAuth2Authenticator(auths.Authenticator): return self.authorizationEndpoint.value + '?' + params - def _requestToken(self, code: str, code_verifier: typing.Optional[str] = None) -> TokenInfo: + def _request_token(self, code: str, code_verifier: typing.Optional[str] = None) -> TokenInfo: """Request a token from the token endpoint using the code received from the authorization endpoint Args: @@ -342,7 +342,7 @@ class OAuth2Authenticator(auths.Authenticator): return TokenInfo.from_dict(req.json()) - def _requestInfo(self, token: 'TokenInfo') -> dict[str, typing.Any]: + def _request_info(self, token: 'TokenInfo') -> dict[str, typing.Any]: """Request user info from the info endpoint using the token received from the token endpoint If the token endpoint returns the user info, this method will not be used @@ -374,7 +374,7 @@ class OAuth2Authenticator(auths.Authenticator): userInfo = req.json() return userInfo - def _processToken( + def _process_token( self, userInfo: collections.abc.Mapping[str, typing.Any], gm: 'auths.GroupsManager' ) -> types.auth.AuthenticationResult: # After this point, we don't mind about the token, we only need to authenticate user @@ -401,7 +401,7 @@ class OAuth2Authenticator(auths.Authenticator): # and if we are here, the user is authenticated, so we can return SUCCESS_AUTH return types.auth.AuthenticationResult(types.auth.AuthenticationState.SUCCESS, username=username) - def _processTokenOpenId( + def _process_token_open_id( self, token_id: str, nonce: str, gm: 'auths.GroupsManager' ) -> types.auth.AuthenticationResult: # Get token headers, to extract algorithm @@ -410,7 +410,7 @@ class OAuth2Authenticator(auths.Authenticator): # We may have multiple public keys, try them all # (We should only have one, but just in case) - for key in self._getPublicKeys(): + for key in self._get_public_keys(): logger.debug('Key = %s', key) try: payload = jwt.decode(token, key=key, audience=self.clientId.value, algorithms=[info.get('alg', 'RSA256')]) # type: ignore @@ -423,7 +423,7 @@ class OAuth2Authenticator(auths.Authenticator): # All is fine, get user & look for groups # Process attributes from payload - return self._processToken(payload, gm) + return self._process_token(payload, gm) except (jwt.InvalidTokenError, IndexError): # logger.debug('Data was invalid: %s', e) pass @@ -477,13 +477,13 @@ class OAuth2Authenticator(auths.Authenticator): ) -> types.auth.AuthenticationResult: match self.responseType.value: case 'code' | 'pkce': - return self.authCallbackCode(parameters, gm, request) + return self.auth_callback_code(parameters, gm, request) case 'token': - return self.authCallbackToken(parameters, gm, request) + return self.auth_callback_token(parameters, gm, request) case 'openid+code': - return self.authCallbackOpenIdCode(parameters, gm, request) + return self.auth_callback_openid_code(parameters, gm, request) case 'openid+token_id': - return self.authCallbackOpenIdIdToken(parameters, gm, request) + return self.authcallback_openid_id_token(parameters, gm, request) case _: raise Exception('Invalid response type') return auths.SUCCESS_AUTH @@ -499,21 +499,21 @@ class OAuth2Authenticator(auths.Authenticator): """ We will here compose the azure request and send it via http-redirect """ - return f'window.location="{self._getLoginURL(request)}";' + return f'window.location="{self._get_login_url(request)}";' def get_groups(self, username: str, groupsManager: 'auths.GroupsManager'): - data = self.storage.getPickle(username) + data = self.storage.get_unpickle(username) if not data: return groupsManager.validate(data[1]) def get_real_name(self, username: str) -> str: - data = self.storage.getPickle(username) + data = self.storage.get_unpickle(username) if not data: return username return data[0] - def authCallbackCode( + def auth_callback_code( self, parameters: 'types.auth.AuthCallbackParams', gm: 'auths.GroupsManager', @@ -539,10 +539,10 @@ class OAuth2Authenticator(auths.Authenticator): if code_verifier == 'none': code_verifier = None - token = self._requestToken(code, code_verifier) - return self._processToken(self._requestInfo(token), gm) + token = self._request_token(code, code_verifier) + return self._process_token(self._request_info(token), gm) - def authCallbackToken( + def auth_callback_token( self, parameters: 'types.auth.AuthCallbackParams', gm: 'auths.GroupsManager', @@ -568,9 +568,9 @@ class OAuth2Authenticator(auths.Authenticator): info={}, id_token=None, ) - return self._processToken(self._requestInfo(token), gm) + return self._process_token(self._request_info(token), gm) - def authCallbackOpenIdCode( + def auth_callback_openid_code( self, parameters: 'types.auth.AuthCallbackParams', gm: 'auths.GroupsManager', @@ -592,15 +592,15 @@ class OAuth2Authenticator(auths.Authenticator): return types.auth.FAILED_AUTH # Get the token, token_type, expires - token = self._requestToken(code) + token = self._request_token(code) if not token.id_token: logger.error('No id_token received on OAuth2 callback') return types.auth.FAILED_AUTH - return self._processTokenOpenId(token.id_token, nonce, gm) + return self._process_token_open_id(token.id_token, nonce, gm) - def authCallbackOpenIdIdToken( + def authcallback_openid_id_token( self, parameters: 'types.auth.AuthCallbackParams', gm: 'auths.GroupsManager', @@ -621,4 +621,4 @@ class OAuth2Authenticator(auths.Authenticator): logger.error('Invalid id_token received on OAuth2 callback') return types.auth.FAILED_AUTH - return self._processTokenOpenId(id_token, nonce, gm) + return self._process_token_open_id(id_token, nonce, gm) diff --git a/server/src/uds/auths/Radius/authenticator.py b/server/src/uds/auths/Radius/authenticator.py index 27ceea5cd..4a704751e 100644 --- a/server/src/uds/auths/Radius/authenticator.py +++ b/server/src/uds/auths/Radius/authenticator.py @@ -37,7 +37,7 @@ import collections.abc from django.utils.translation import gettext_noop as _ from uds.core import auths, types, consts -from uds.core.auths.auth import authenticate_log_login +from uds.core.auths.auth import log_login from uds.core.managers.crypto import CryptoManager from uds.core.ui import gui @@ -125,7 +125,7 @@ class RadiusAuth(auths.Authenticator): def initialize(self, values: typing.Optional[dict[str, typing.Any]]) -> None: pass - def radiusClient(self) -> client.RadiusClient: + def radius_client(self) -> client.RadiusClient: """Return a new radius client .""" return client.RadiusClient( self.server.value, @@ -135,11 +135,11 @@ class RadiusAuth(auths.Authenticator): appClassPrefix=self.appClassPrefix.value, ) - def mfaStorageKey(self, username: str) -> str: + def mfa_storage_key(self, username: str) -> str: return 'mfa_' + str(self.db_obj().uuid) + username def mfa_identifier(self, username: str) -> str: - return self.storage.getPickle(self.mfaStorageKey(username)) or '' + return self.storage.get_unpickle(self.mfa_storage_key(username)) or '' def authenticate( self, @@ -149,7 +149,7 @@ class RadiusAuth(auths.Authenticator): request: 'ExtendedHttpRequest', ) -> types.auth.AuthenticationResult: try: - connection = self.radiusClient() + connection = self.radius_client() groups, mfaCode, state = connection.authenticate(username=username, password=credentials, mfaField=self.mfaAttr.value.strip()) # If state, store in session if state: @@ -157,12 +157,12 @@ class RadiusAuth(auths.Authenticator): # store the user mfa attribute if it is set if mfaCode: self.storage.put_pickle( - self.mfaStorageKey(username), + self.mfa_storage_key(username), mfaCode, ) except Exception: - authenticate_log_login( + log_login( request, self.db_obj(), username, @@ -200,15 +200,15 @@ class RadiusAuth(auths.Authenticator): """Test the connection to the server .""" try: auth = RadiusAuth(None, env, data) # type: ignore - return auth.testConnection() + return auth.test_connection() except Exception as e: logger.error("Exception found testing Radius auth %s: %s", e.__class__, e) return [False, _('Error testing connection')] - def testConnection(self): + def test_connection(self): """Test connection to Radius Server""" try: - connection = self.radiusClient() + connection = self.radius_client() # Reply is not important... connection.authenticate( CryptoManager().random_string(10), CryptoManager().random_string(10) diff --git a/server/src/uds/auths/RegexLdap/authenticator.py b/server/src/uds/auths/RegexLdap/authenticator.py index fd4e26e0c..2a1eae201 100644 --- a/server/src/uds/auths/RegexLdap/authenticator.py +++ b/server/src/uds/auths/RegexLdap/authenticator.py @@ -40,7 +40,7 @@ import ldap from django.utils.translation import gettext_noop as _ from uds.core import auths, exceptions, types, consts -from uds.core.auths.auth import authenticate_log_login +from uds.core.auths.auth import log_login from uds.core.ui import gui from uds.core.util import ldaputil, auth as auth_utils @@ -275,7 +275,7 @@ class RegexLdap(auths.Authenticator): return 'mfa_' + self.db_obj().uuid + username def mfa_identifier(self, username: str) -> str: - return self.storage.getPickle(self.mfaStorageKey(username)) or '' + return self.storage.get_unpickle(self.mfaStorageKey(username)) or '' def dict_of_values(self) -> gui.ValuesDictType: return { @@ -494,14 +494,14 @@ class RegexLdap(auths.Authenticator): usr = self.__getUser(username) if usr is None: - authenticate_log_login(request, self.db_obj(), username, 'Invalid user') + log_login(request, self.db_obj(), username, 'Invalid user') return types.auth.FAILED_AUTH try: # Let's see first if it credentials are fine self.__connectAs(usr['dn'], credentials) # Will raise an exception if it can't connect except Exception: - authenticate_log_login(request, self.db_obj(), username, 'Invalid password') + log_login(request, self.db_obj(), username, 'Invalid password') return types.auth.FAILED_AUTH # store the user mfa attribute if it is set diff --git a/server/src/uds/auths/SAML/saml.py b/server/src/uds/auths/SAML/saml.py index 0e57f7894..593af61f9 100644 --- a/server/src/uds/auths/SAML/saml.py +++ b/server/src/uds/auths/SAML/saml.py @@ -567,7 +567,7 @@ class SAMLAuthenticator(auths.Authenticator): self.storage.remove(self.mfaStorageKey(username)) def mfa_identifier(self, username: str) -> str: - return self.storage.getPickle(self.mfaStorageKey(username)) or '' + return self.storage.get_unpickle(self.mfaStorageKey(username)) or '' def logoutFromCallback( self, @@ -725,13 +725,13 @@ class SAMLAuthenticator(auths.Authenticator): ) def get_groups(self, username: str, groupsManager: 'auths.GroupsManager'): - data = self.storage.getPickle(username) + data = self.storage.get_unpickle(username) if not data: return groupsManager.validate(data[1]) def get_real_name(self, username: str) -> str: - data = self.storage.getPickle(username) + data = self.storage.get_unpickle(username) if not data: return username return data[0] diff --git a/server/src/uds/auths/SimpleLDAP/authenticator.py b/server/src/uds/auths/SimpleLDAP/authenticator.py index b9ff08196..0415d8568 100644 --- a/server/src/uds/auths/SimpleLDAP/authenticator.py +++ b/server/src/uds/auths/SimpleLDAP/authenticator.py @@ -38,7 +38,7 @@ import ldap.filter from django.utils.translation import gettext_noop as _ from uds.core import auths, types, consts, exceptions -from uds.core.auths.auth import authenticate_log_login +from uds.core.auths.auth import log_login from uds.core.ui import gui from uds.core.util import ldaputil @@ -316,7 +316,7 @@ class SimpleLDAPAuthenticator(auths.Authenticator): return 'mfa_' + str(self.db_obj().uuid) + username def mfa_identifier(self, username: str) -> str: - return self.storage.getPickle(self.mfaStorageKey(username)) or '' + return self.storage.get_unpickle(self.mfaStorageKey(username)) or '' def __connection(self): """ @@ -447,14 +447,14 @@ class SimpleLDAPAuthenticator(auths.Authenticator): user = self.__getUser(username) if user is None: - authenticate_log_login(request, self.db_obj(), username, 'Invalid user') + log_login(request, self.db_obj(), username, 'Invalid user') return types.auth.FAILED_AUTH try: # Let's see first if it credentials are fine self.__connectAs(user['dn'], credentials) # Will raise an exception if it can't connect except Exception: - authenticate_log_login(request, self.db_obj(), username, 'Invalid password') + log_login(request, self.db_obj(), username, 'Invalid password') return types.auth.FAILED_AUTH # store the user mfa attribute if it is set diff --git a/server/src/uds/core/auths/auth.py b/server/src/uds/core/auths/auth.py index dfbdcac41..91337873f 100644 --- a/server/src/uds/core/auths/auth.py +++ b/server/src/uds/core/auths/auth.py @@ -513,7 +513,7 @@ def web_logout( return response -def authenticate_log_login( +def log_login( request: 'ExtendedHttpRequest', authenticator: models.Authenticator, userName: str, @@ -558,7 +558,7 @@ def authenticate_log_login( logger.info('Root %s from %s where OS is %s', log_string, request.ip, request.os.os.name) -def auth_log_logout(request: 'ExtendedHttpRequest') -> None: +def log_logout(request: 'ExtendedHttpRequest') -> None: if request.user: if request.user.manager.id is not None: log.log( diff --git a/server/src/uds/core/consts/auth.py b/server/src/uds/core/consts/auth.py index b48f53cf9..23a89eb58 100644 --- a/server/src/uds/core/consts/auth.py +++ b/server/src/uds/core/consts/auth.py @@ -39,8 +39,12 @@ HIDDEN: typing.Final[str] = 'h' DISABLED: typing.Final[str] = 'd' # net_filter -# Note: this are STANDARD values used on "default field" networks on RESP API -# Named them for better reading, but cannot be changed, since they are used on RESP API +# Note: this are STANDARD values used on "default field" networks on REST API +# Named them for better reading, but cannot be changed, since they are used on REST API NO_FILTERING: typing.Final[str] = 'n' ALLOW: typing.Final[str] = 'a' DENY: typing.Final[str] = 'd' + +# Cookie for mfa and csrf field +MFA_COOKIE_NAME: typing.Final[str] = 'mfa_status' +CSRF_FIELD: typing.Final[str] = 'csrfmiddlewaretoken' diff --git a/server/src/uds/core/managers/crypto.py b/server/src/uds/core/managers/crypto.py index 4096e4c28..4b81d24ec 100644 --- a/server/src/uds/core/managers/crypto.py +++ b/server/src/uds/core/managers/crypto.py @@ -315,3 +315,9 @@ class CryptoManager(metaclass=singleton.Singleton): return hashlib.sha3_256( (self.random_string(24, True) + datetime.datetime.now().strftime('%H%M%S%f')).encode() ).hexdigest() + + def sha(self, value: typing.Union[str, bytes]) -> str: + if isinstance(value, str): + value = value.encode() + + return hashlib.sha3_256(value).hexdigest() diff --git a/server/src/uds/core/mfas/mfa.py b/server/src/uds/core/mfas/mfa.py index 6c1ac718f..e67446eea 100644 --- a/server/src/uds/core/mfas/mfa.py +++ b/server/src/uds/core/mfas/mfa.py @@ -215,7 +215,7 @@ class MFA(Module): Internal method to get the data from storage """ storageKey = request.ip + userId - return self.storage.getPickle(storageKey) + return self.storage.get_unpickle(storageKey) def _remove_data(self, request: 'ExtendedHttpRequest', userId: str) -> None: """ diff --git a/server/src/uds/core/services/service.py b/server/src/uds/core/services/service.py index 5ce2e5be0..2eabae614 100644 --- a/server/src/uds/core/services/service.py +++ b/server/src/uds/core/services/service.py @@ -443,7 +443,7 @@ class Service(Module): def recover_id_info(self, id: str, delete: bool = False) -> typing.Any: # recovers the information - value = self.storage.getPickle('__nfo_' + id) + value = self.storage.get_unpickle('__nfo_' + id) if value and delete: self.storage.delete('__nfo_' + id) return value diff --git a/server/src/uds/core/util/cache.py b/server/src/uds/core/util/cache.py index 0253094bf..335a0dfd4 100644 --- a/server/src/uds/core/util/cache.py +++ b/server/src/uds/core/util/cache.py @@ -65,28 +65,24 @@ class Cache: return serializer.deserialize(codecs.decode(value.encode(), 'base64')) _serializer: typing.ClassVar[collections.abc.Callable[[typing.Any], str]] = _basic_serialize - _deserializer: typing.ClassVar[ - collections.abc.Callable[[str], typing.Any] - ] = _basic_deserialize + _deserializer: typing.ClassVar[collections.abc.Callable[[str], typing.Any]] = _basic_deserialize def __init__(self, owner: typing.Union[str, bytes]): self._owner = typing.cast(str, owner.decode('utf-8') if isinstance(owner, bytes) else owner) self._bowner = self._owner.encode('utf8') - def __getKey(self, key: typing.Union[str, bytes]) -> str: + def __get_key(self, key: typing.Union[str, bytes]) -> str: if isinstance(key, str): key = key.encode('utf8') return hash_key(self._bowner + key) - def get( - self, skey: typing.Union[str, bytes], defValue: typing.Any = None - ) -> typing.Any: + def get(self, skey: typing.Union[str, bytes], defValue: typing.Any = None) -> typing.Any: now = sql_datetime() # logger.debug('Requesting key "%s" for cache "%s"', skey, self._owner) try: - key = self.__getKey(skey) + key = self.__get_key(skey) # logger.debug('Key: %s', key) - c: DBCache = DBCache.objects.get(pk=key) # @UndefinedVariable + c: DBCache = DBCache.objects.get(owner=self._owner, pk=key) # If expired if now > c.created + datetime.timedelta(seconds=c.validity): return defValue @@ -105,21 +101,22 @@ class Cache: Cache.misses += 1 # logger.debug('key not found: %s', skey) return defValue - #except OperationalError: - # If database is not ready, just return default value - # This is not a big issue, since cache is not critical - # and probably will be generated by sqlite on high concurrency - + # except OperationalError: + # If database is not ready, just return default value + # This is not a big issue, since cache is not critical + # and probably will be generated by sqlite on high concurrency + # Cache.misses += 1 # return defValue except Exception: import inspect + # Get caller error = 'Error Getting cache key from: ' for caller in inspect.stack(): error += f'{caller.filename}:{caller.lineno} -> ' logger.error(error) - + # logger.exception('Error getting cache key: %s', skey) Cache.misses += 1 return defValue @@ -137,7 +134,7 @@ class Cache: """ # logger.debug('Removing key "%s" for uService "%s"' % (skey, self._owner)) try: - key = self.__getKey(skey) + key = self.__get_key(skey) DBCache.objects.get(pk=key).delete() # @UndefinedVariable return True except DBCache.DoesNotExist: # @UndefinedVariable @@ -163,22 +160,34 @@ class Cache: # logger.debug('Saving key "%s" for cache "%s"' % (skey, self._owner,)) if validity is None: validity = consts.system.DEFAULT_CACHE_TIMEOUT - key = self.__getKey(skey) + key = self.__get_key(skey) strValue = Cache._serializer(value) now = sql_datetime() # Remove existing if any and create a new one with transaction.atomic(): try: - # Remove if existing - DBCache.objects.filter(pk=key).delete() - # And create a new one - DBCache.objects.create( + DBCache.objects.update_or_create( + pk=key, owner=self._owner, - key=key, - value=strValue, - created=now, - validity=validity, - ) # @UndefinedVariable + defaults={ + 'owner': self._owner, + 'key': key, + 'value': strValue, + 'created': now, + 'validity': validity, + }, + ) + + # # Remove if existing + # DBCache.objects.filter(pk=key).delete() + # # And create a new one + # DBCache.objects.create( + # owner=self._owner, + # key=key, + # value=strValue, + # created=now, + # validity=validity, + # ) # @UndefinedVariable return # And return except Exception as e: logger.debug('Transaction in course, cannot store value: %s', e) @@ -192,7 +201,7 @@ class Cache: def refresh(self, skey: typing.Union[str, bytes]) -> None: # logger.debug('Refreshing key "%s" for cache "%s"' % (skey, self._owner,)) try: - key = self.__getKey(skey) + key = self.__get_key(skey) c = DBCache.objects.get(pk=key) c.created = sql_datetime() c.save() diff --git a/server/src/uds/core/util/storage.py b/server/src/uds/core/util/storage.py index 9d6b3a206..b8c4ff02a 100644 --- a/server/src/uds/core/util/storage.py +++ b/server/src/uds/core/util/storage.py @@ -46,7 +46,7 @@ logger = logging.getLogger(__name__) MARK = '_mgb_' -def _calcKey(owner: bytes, key: bytes, extra: typing.Optional[bytes] = None) -> str: +def _calculate_key(owner: bytes, key: bytes, extra: typing.Optional[bytes] = None) -> str: h = hashlib.md5(usedforsecurity=False) h.update(owner) h.update(key) @@ -55,14 +55,14 @@ def _calcKey(owner: bytes, key: bytes, extra: typing.Optional[bytes] = None) -> return h.hexdigest() -def _encodeValue(key: str, value: typing.Any, compat: bool = False) -> str: +def _encode_value(key: str, value: typing.Any, compat: bool = False) -> str: if not compat: return base64.b64encode(pickle.dumps((MARK, key, value))).decode() # Compatibility save return base64.b64encode(pickle.dumps(value)).decode() -def _decodeValue(dbk: str, value: typing.Optional[str]) -> tuple[str, typing.Any]: +def _decode_value(dbk: str, value: typing.Optional[str]) -> tuple[str, typing.Any]: if value: try: v = pickle.loads(base64.b64decode(value.encode())) # nosec: This is e controled pickle loading @@ -121,7 +121,7 @@ class StorageAsDict(MutableMapping): if key[0] == '#': # Compat with old db key return key[1:] - return _calcKey(self._owner.encode(), key.encode()) + return _calculate_key(self._owner.encode(), key.encode()) def __getitem__(self, key: str) -> typing.Any: if not isinstance(key, str): @@ -130,7 +130,11 @@ class StorageAsDict(MutableMapping): dbk = self._key(key) try: c: DBStorage = typing.cast(DBStorage, self._db.get(pk=dbk)) - return _decodeValue(dbk, c.data)[1] # Ignores original key + if c.owner != self._owner: # Maybe a key collision, + logger.error('Key collision detected for key %s', key) + return None + okey, value = _decode_value(dbk, c.data) + return _decode_value(dbk, c.data)[1] # Ignores original key except DBStorage.DoesNotExist: return None @@ -139,7 +143,7 @@ class StorageAsDict(MutableMapping): raise TypeError(f'Key must be str type, {type(key)} found') dbk = self._key(key) - data = _encodeValue(key, value, self._compat) + data = _encode_value(key, value, self._compat) # ignores return value, we don't care if it was created or updated DBStorage.objects.update_or_create( key=dbk, defaults={'data': data, 'attr1': self._group, 'owner': self._owner} @@ -153,7 +157,7 @@ class StorageAsDict(MutableMapping): """ Iterates through keys """ - return iter(_decodeValue(i.key, i.data)[0] for i in self._filtered) + return iter(_decode_value(i.key, i.data)[0] for i in self._filtered) def __contains__(self, key: object) -> bool: if isinstance(key, str): @@ -165,10 +169,10 @@ class StorageAsDict(MutableMapping): # Optimized methods, avoid re-reading from DB def items(self): - return iter(_decodeValue(i.key, i.data) for i in self._filtered) + return iter(_decode_value(i.key, i.data) for i in self._filtered) def values(self): - return iter(_decodeValue(i.key, i.data)[1] for i in self._filtered) + return iter(_decode_value(i.key, i.data)[1] for i in self._filtered) def get(self, key: str, default: typing.Any = None) -> typing.Any: return self[key] or default @@ -191,6 +195,11 @@ class StorageAccess: Allows the access to the storage as a dict, with atomic transaction if requested """ + owner: str + group: typing.Optional[str] + atomic: bool + compat: bool + def __init__( self, owner: str, @@ -223,13 +232,13 @@ class Storage: _bownwer: bytes def __init__(self, owner: typing.Union[str, bytes]): - self._owner = owner.decode('utf-8') if isinstance(owner, bytes) else owner + self._owner = typing.cast(str, owner.decode('utf-8') if isinstance(owner, bytes) else owner) self._bowner = self._owner.encode('utf8') - def getKey(self, key: typing.Union[str, bytes]) -> str: - return _calcKey(self._bowner, key.encode('utf8') if isinstance(key, str) else key) + def get_key(self, key: typing.Union[str, bytes]) -> str: + return _calculate_key(self._bowner, key.encode('utf8') if isinstance(key, str) else key) - def saveData( + def save_to_db( self, skey: typing.Union[str, bytes], data: typing.Any, @@ -240,21 +249,21 @@ class Storage: self.remove(skey) return - key = self.getKey(skey) + key = self.get_key(skey) if isinstance(data, str): data = data.encode('utf-8') - dataStr = codecs.encode(data, 'base64').decode() + data_string = codecs.encode(data, 'base64').decode() attr1 = attr1 or '' try: - DBStorage.objects.create(owner=self._owner, key=key, data=dataStr, attr1=attr1) + DBStorage.objects.create(owner=self._owner, key=key, data=data_string, attr1=attr1) except Exception: with transaction.atomic(): DBStorage.objects.filter(key=key).select_for_update().update( - owner=self._owner, data=dataStr, attr1=attr1 + owner=self._owner, data=data_string, attr1=attr1 ) # @UndefinedVariable def put(self, skey: typing.Union[str, bytes], data: typing.Any) -> None: - return self.saveData(skey, data) + return self.save_to_db(skey, data) def put_pickle( self, @@ -262,23 +271,23 @@ class Storage: data: typing.Any, attr1: typing.Optional[str] = None, ) -> None: - return self.saveData( + return self.save_to_db( skey, pickle.dumps(data), attr1 ) # Protocol 2 is compatible with python 2.7. This will be unnecesary when fully migrated - def updateData( + def update_to_db( self, skey: typing.Union[str, bytes], data: typing.Any, attr1: typing.Optional[str] = None, ) -> None: - self.saveData(skey, data, attr1) + self.save_to_db(skey, data, attr1) - def readData( + def read_from_db( self, skey: typing.Union[str, bytes], fromPickle: bool = False ) -> typing.Optional[typing.Union[str, bytes]]: try: - key = self.getKey(skey) + key = self.get_key(skey) c: DBStorage = DBStorage.objects.get(pk=key) # @UndefinedVariable val = codecs.decode(c.data.encode(), 'base64') @@ -293,15 +302,15 @@ class Storage: return None def get(self, skey: typing.Union[str, bytes]) -> typing.Optional[typing.Union[str, bytes]]: - return self.readData(skey) + return self.read_from_db(skey) - def getPickle(self, skey: typing.Union[str, bytes]) -> typing.Any: - v = self.readData(skey, True) + def get_unpickle(self, skey: typing.Union[str, bytes]) -> typing.Any: + v = self.read_from_db(skey, True) if v: return pickle.loads(typing.cast(bytes, v)) # nosec: This is e controled pickle loading return None - def getPickleByAttr1(self, attr1: str, forUpdate: bool = False): + def get_unpickle_by_attr1(self, attr1: str, forUpdate: bool = False): try: query = DBStorage.objects.filter(owner=self._owner, attr1=attr1) if forUpdate: @@ -312,11 +321,16 @@ class Storage: except Exception: return None - def remove(self, skey: typing.Union[collections.abc.Iterable[typing.Union[str, bytes]], str, bytes]) -> None: - keys: collections.abc.Iterable[typing.Union[str, bytes]] = [skey] if isinstance(skey, (str, bytes)) else skey + def remove( + self, skey: typing.Union[collections.abc.Iterable[typing.Union[str, bytes]], str, bytes] + ) -> None: + keys: collections.abc.Iterable[typing.Union[str, bytes]] = typing.cast( + collections.abc.Iterable[typing.Union[str, bytes]], + [skey] if isinstance(skey, (str, bytes)) else skey, + ) try: # Process several keys at once - DBStorage.objects.filter(key__in=[self.getKey(k) for k in keys]).delete() + DBStorage.objects.filter(key__in=[self.get_key(k) for k in keys]).delete() except Exception: # nosec: Not interested in processing exceptions, just ignores it pass @@ -342,7 +356,9 @@ class Storage: ) -> StorageAccess: return StorageAccess(self._owner, group=group, atomic=atomic, compat=compat) - def locateByAttr1(self, attr1: typing.Union[collections.abc.Iterable[str], str]) -> collections.abc.Iterable[bytes]: + def search_by_attr1( + self, attr1: typing.Union[collections.abc.Iterable[str], str] + ) -> collections.abc.Iterable[bytes]: if isinstance(attr1, str): query = DBStorage.objects.filter(owner=self._owner, attr1=attr1) # @UndefinedVariable else: @@ -365,7 +381,7 @@ class Storage: for v in query: # @UndefinedVariable yield (v.key, codecs.decode(v.data.encode(), 'base64'), v.attr1) - def filterPickle( + def filter_unpickle( self, attr1: typing.Optional[str] = None, forUpdate: bool = False ) -> collections.abc.Iterable[tuple[str, typing.Any, 'str|None']]: for v in self.filter(attr1, forUpdate): diff --git a/server/src/uds/mfas/TOTP/mfa.py b/server/src/uds/mfas/TOTP/mfa.py index fba87e761..54f9c986a 100644 --- a/server/src/uds/mfas/TOTP/mfa.py +++ b/server/src/uds/mfas/TOTP/mfa.py @@ -128,7 +128,7 @@ class TOTP_MFA(mfas.MFA): # Get data from storage related to this user # Data contains the secret and if the user has already logged in already some time # so we show the QR code only once - data: typing.Optional[tuple[str, bool]] = self.storage.getPickle(userId) + data: typing.Optional[tuple[str, bool]] = self.storage.get_unpickle(userId) if data is None: data = (pyotp.random_base32(), False) self._saveUserData(userId, data) diff --git a/server/src/uds/models/transport.py b/server/src/uds/models/transport.py index f2d8f1e46..62d5cd63d 100644 --- a/server/src/uds/models/transport.py +++ b/server/src/uds/models/transport.py @@ -36,7 +36,7 @@ import collections.abc from django.db import models -from uds.core import transports, types +from uds.core import transports, types, consts from uds.core.util import net @@ -58,13 +58,8 @@ class Transport(ManagedObjectModel, TaggingMixin): Sample of transports are RDP, Spice, Web file uploader, etc... """ - # Constants for net_filter - NO_FILTERING = 'n' - ALLOW = 'a' - DENY = 'd' - priority = models.IntegerField(default=0, db_index=True) - net_filtering = models.CharField(max_length=1, default=NO_FILTERING, db_index=True) + net_filtering = models.CharField(max_length=1, default=consts.auth.NO_FILTERING, db_index=True) # We store allowed oss as a comma-separated list allowed_oss = models.CharField(max_length=255, default='') # Label, to group transports on meta pools @@ -130,14 +125,14 @@ class Transport(ManagedObjectModel, TaggingMixin): # Avoid circular import from uds.models import Network # pylint: disable=import-outside-toplevel - if self.net_filtering == Transport.NO_FILTERING: + if self.net_filtering == consts.auth.NO_FILTERING: return True ip, version = net.ip_to_long(ipStr) # Allow exists = self.networks.filter( start__lte=Network.hexlify(ip), end__gte=Network.hexlify(ip), version=version ).exists() - if self.net_filtering == Transport.ALLOW: + if self.net_filtering == consts.auth.ALLOW: return exists # Deny, must not be in any network return not exists diff --git a/server/src/uds/notifiers/telegram/notifier.py b/server/src/uds/notifiers/telegram/notifier.py index e3eae2ae6..985c220e4 100644 --- a/server/src/uds/notifiers/telegram/notifier.py +++ b/server/src/uds/notifiers/telegram/notifier.py @@ -140,7 +140,7 @@ class TelegramNotifier(messaging.Notifier): telegramMsg = f'{group} - {identificator} - {str(level)}: {message}' logger.debug('Sending telegram message: %s', telegramMsg) # load chatIds - chatIds = self.storage.getPickle('chatIds') or [] + chatIds = self.storage.get_unpickle('chatIds') or [] t = telegram.Telegram(self.accessToken.value, self.botname.value) for chatId in chatIds: with ignoreExceptions(): @@ -151,7 +151,7 @@ class TelegramNotifier(messaging.Notifier): def subscribeUser(self, chatId: int) -> None: # we do not expect to have a lot of users, so we will use a simple storage # that holds a list of chatIds - chatIds = self.storage.getPickle('chatIds') or [] + chatIds = self.storage.get_unpickle('chatIds') or [] if chatId not in chatIds: chatIds.append(chatId) self.storage.put_pickle('chatIds', chatIds) @@ -160,7 +160,7 @@ class TelegramNotifier(messaging.Notifier): def unsubscriteUser(self, chatId: int) -> None: # we do not expect to have a lot of users, so we will use a simple storage # that holds a list of chatIds - chatIds = self.storage.getPickle('chatIds') or [] + chatIds = self.storage.get_unpickle('chatIds') or [] if chatId in chatIds: chatIds.remove(chatId) self.storage.put_pickle('chatIds', chatIds) @@ -170,7 +170,7 @@ class TelegramNotifier(messaging.Notifier): if not self.accessToken.value.strip(): return # no access token, no messages # Time of last retrieve - lastCheck: typing.Optional[datetime.datetime] = self.storage.getPickle('lastCheck') + lastCheck: typing.Optional[datetime.datetime] = self.storage.get_unpickle('lastCheck') now = sql_datetime() # If last check is not set, we will set it to now @@ -185,7 +185,7 @@ class TelegramNotifier(messaging.Notifier): # Update last check self.storage.put_pickle('lastCheck', now) - lastOffset = self.storage.getPickle('lastOffset') or 0 + lastOffset = self.storage.get_unpickle('lastOffset') or 0 t = telegram.Telegram(self.accessToken.value, last_offset=lastOffset) with ignoreExceptions(): # In case getUpdates fails, ignore it for update in t.get_updates(): diff --git a/server/src/uds/services/OVirt/jobs.py b/server/src/uds/services/OVirt/jobs.py index 2133189e0..05645904a 100644 --- a/server/src/uds/services/OVirt/jobs.py +++ b/server/src/uds/services/OVirt/jobs.py @@ -71,10 +71,10 @@ class OVirtDeferredRemoval(jobs.Job): if state in ('up', 'powering_up', 'suspended'): providerInstance.stopMachine(vmId) elif state != 'unknown': # Machine exists, remove it later - providerInstance.storage.saveData('tr' + vmId, vmId, attr1='tRm') + providerInstance.storage.save_to_db('tr' + vmId, vmId, attr1='tRm') except Exception as e: - providerInstance.storage.saveData('tr' + vmId, vmId, attr1='tRm') + providerInstance.storage.save_to_db('tr' + vmId, vmId, attr1='tRm') logger.info( 'Machine %s could not be removed right now, queued for later: %s', vmId, diff --git a/server/src/uds/services/PhysicalMachines/service_multi.py b/server/src/uds/services/PhysicalMachines/service_multi.py index ef3feb52e..ca2791d6b 100644 --- a/server/src/uds/services/PhysicalMachines/service_multi.py +++ b/server/src/uds/services/PhysicalMachines/service_multi.py @@ -169,7 +169,7 @@ class IPMachinesService(IPServiceBase): '{}~{}'.format(str(ip).strip(), i) for i, ip in enumerate(values['ipList']) if str(ip).strip() ] # Allow duplicates right now # Current stored data, if it exists - d = self.storage.readData('ips') + d = self.storage.read_from_db('ips') old_ips = pickle.loads(d) if d and isinstance(d, bytes) else [] # nosec: pickle is safe here # dissapeared ones dissapeared = set(IPServiceBase.getIp(i.split('~')[0]) for i in old_ips) - set( @@ -202,7 +202,7 @@ class IPMachinesService(IPServiceBase): } def marshal(self) -> bytes: - self.storage.saveData('ips', pickle.dumps(self._ips)) + self.storage.save_to_db('ips', pickle.dumps(self._ips)) return b'\0'.join( [ b'v7', @@ -217,7 +217,7 @@ class IPMachinesService(IPServiceBase): def unmarshal(self, data: bytes) -> None: values: list[bytes] = data.split(b'\0') - d = self.storage.readData('ips') + d = self.storage.read_from_db('ips') if isinstance(d, bytes): self._ips = pickle.loads(d) # nosec: pickle is safe here elif isinstance(d, str): # "legacy" saved elements @@ -272,7 +272,7 @@ class IPMachinesService(IPServiceBase): for ip in allIps: theIP = IPServiceBase.getIp(ip) theMAC = IPServiceBase.getMac(ip) - locked = self.storage.getPickle(theIP) + locked = self.storage.get_unpickle(theIP) if self.canBeUsed(locked, now): if ( self._port > 0 @@ -320,7 +320,7 @@ class IPMachinesService(IPServiceBase): logger.exception("Exception at getUnassignedMachine") def enumerate_assignables(self): - return [(ip, ip.split('~')[0]) for ip in self._ips if self.storage.readData(ip) is None] + return [(ip, ip.split('~')[0]) for ip in self._ips if self.storage.read_from_db(ip) is None] def assign_from_assignables( self, @@ -333,7 +333,7 @@ class IPMachinesService(IPServiceBase): theMAC = IPServiceBase.getMac(assignableId) now = sql_stamp_seconds() - locked = self.storage.getPickle(theIP) + locked = self.storage.get_unpickle(theIP) if self.canBeUsed(locked, now): self.storage.put_pickle(theIP, now) if theMAC: @@ -351,7 +351,7 @@ class IPMachinesService(IPServiceBase): # Locate the IP on the storage theIP = IPServiceBase.getIp(id) now = sql_stamp_seconds() - locked: typing.Union[None, str, int] = self.storage.getPickle(theIP) + locked: typing.Union[None, str, int] = self.storage.get_unpickle(theIP) if self.canBeUsed(locked, now): self.storage.put_pickle(theIP, str(now)) # Lock it @@ -362,7 +362,7 @@ class IPMachinesService(IPServiceBase): logger.debug('Processing logout for %s: %s', self, id) # Locate the IP on the storage theIP = IPServiceBase.getIp(id) - locked: typing.Union[None, str, int] = self.storage.getPickle(theIP) + locked: typing.Union[None, str, int] = self.storage.get_unpickle(theIP) # If locked is str, has been locked by processLogin so we can unlock it if isinstance(locked, str): self.unassignMachine(id) diff --git a/server/src/uds/services/PhysicalMachines/service_single.py b/server/src/uds/services/PhysicalMachines/service_single.py index 147c70a0a..8dba164f8 100644 --- a/server/src/uds/services/PhysicalMachines/service_single.py +++ b/server/src/uds/services/PhysicalMachines/service_single.py @@ -88,7 +88,7 @@ class IPSingleMachineService(IPServiceBase): def getUnassignedMachine(self) -> typing.Optional[str]: ip: typing.Optional[str] = None try: - counter = self.storage.getPickle('counter') + counter = self.storage.get_unpickle('counter') counter = counter + 1 if counter is not None else 1 self.storage.put_pickle('counter', counter) ip = '{}~{}'.format(self.ip.value, counter) diff --git a/server/src/uds/services/Proxmox/deployment.py b/server/src/uds/services/Proxmox/deployment.py index 6172783c4..1bfc23485 100644 --- a/server/src/uds/services/Proxmox/deployment.py +++ b/server/src/uds/services/Proxmox/deployment.py @@ -513,7 +513,7 @@ if sys.platform == 'win32': """ Check if the machine has gracely stopped (timed shutdown) """ - shutdown_start = self.storage.getPickle('shutdown') + shutdown_start = self.storage.get_unpickle('shutdown') logger.debug('Shutdown start: %s', shutdown_start) if shutdown_start < 0: # Was already stopped # Machine is already stop diff --git a/server/src/uds/services/Proxmox/jobs.py b/server/src/uds/services/Proxmox/jobs.py index f8aa33ae1..f5211f138 100644 --- a/server/src/uds/services/Proxmox/jobs.py +++ b/server/src/uds/services/Proxmox/jobs.py @@ -72,7 +72,7 @@ class ProxmoxDeferredRemoval(jobs.Job): except client.ProxmoxNotFound: return # Machine does not exists except Exception as e: - providerInstance.storage.saveData('tr' + str(vmId), str(vmId), attr1='tRm') + providerInstance.storage.save_to_db('tr' + str(vmId), str(vmId), attr1='tRm') logger.info( 'Machine %s could not be removed right now, queued for later: %s', vmId, diff --git a/server/src/uds/services/Sample/deployment_one.py b/server/src/uds/services/Sample/deployment_one.py index f753d3064..1f755b154 100644 --- a/server/src/uds/services/Sample/deployment_one.py +++ b/server/src/uds/services/Sample/deployment_one.py @@ -117,13 +117,13 @@ class SampleUserServiceOne(services.UserService): a new unique name, so we keep the first generated name cached and don't generate more names. (Generator are simple utility classes) """ - name: str = typing.cast(str, self.storage.readData('name')) + name: str = typing.cast(str, self.storage.read_from_db('name')) if name is None: name = self.name_generator().get( self.service().get_base_name() + '-' + self.service().getColour(), 3 ) # Store value for persistence - self.storage.saveData('name', name) + self.storage.save_to_db('name', name) return name @@ -139,7 +139,7 @@ class SampleUserServiceOne(services.UserService): :note: This IP is the IP of the "consumed service", so the transport can access it. """ - self.storage.saveData('ip', ip) + self.storage.save_to_db('ip', ip) def get_unique_id(self) -> str: """ @@ -151,10 +151,10 @@ class SampleUserServiceOne(services.UserService): The get method of a mac generator takes one param, that is the mac range to use to get an unused mac. """ - mac = typing.cast(str, self.storage.readData('mac')) + mac = typing.cast(str, self.storage.read_from_db('mac')) if mac is None: mac = self.mac_generator().get('00:00:00:00:00:00-00:FF:FF:FF:FF:FF') - self.storage.saveData('mac', mac) + self.storage.save_to_db('mac', mac) return mac def get_ip(self) -> str: @@ -175,7 +175,7 @@ class SampleUserServiceOne(services.UserService): show the IP to the administrator, this method will get called """ - ip = typing.cast(str, self.storage.readData('ip')) + ip = typing.cast(str, self.storage.read_from_db('ip')) if ip is None: ip = '192.168.0.34' # Sample IP for testing purposses only return ip @@ -241,11 +241,11 @@ class SampleUserServiceOne(services.UserService): """ import random - self.storage.saveData('count', '0') + self.storage.save_to_db('count', '0') # random fail if random.randint(0, 9) == 9: # nosec: just testing values - self.storage.saveData('error', 'Random error at deployForUser :-)') + self.storage.save_to_db('error', 'Random error at deployForUser :-)') return State.ERROR return State.RUNNING @@ -274,7 +274,7 @@ class SampleUserServiceOne(services.UserService): import random countStr: typing.Optional[str] = typing.cast( - str, self.storage.readData('count') + str, self.storage.read_from_db('count') ) count: int = 0 if countStr: @@ -288,10 +288,10 @@ class SampleUserServiceOne(services.UserService): # random fail if random.randint(0, 9) == 9: # nosec: just testing values - self.storage.saveData('error', 'Random error at check_state :-)') + self.storage.save_to_db('error', 'Random error at check_state :-)') return State.ERROR - self.storage.saveData('count', str(count)) + self.storage.save_to_db('count', str(count)) return State.RUNNING def finish(self) -> None: @@ -322,7 +322,7 @@ class SampleUserServiceOne(services.UserService): The user provided is just an string, that is provided by actor. """ # We store the value at storage, but never get used, just an example - self.storage.saveData('user', username) + self.storage.save_to_db('user', username) def user_logged_out(self, username) -> None: """ @@ -349,7 +349,7 @@ class SampleUserServiceOne(services.UserService): for it, and it will be asked everytime it's needed to be shown to the user (when the administation asks for it). """ - return typing.cast(str, self.storage.readData('error')) or 'No error' + return typing.cast(str, self.storage.read_from_db('error')) or 'No error' def destroy(self) -> str: """ diff --git a/server/src/uds/services/Sample/deployment_two.py b/server/src/uds/services/Sample/deployment_two.py index c9f7985ef..e09ba5f75 100644 --- a/server/src/uds/services/Sample/deployment_two.py +++ b/server/src/uds/services/Sample/deployment_two.py @@ -388,7 +388,7 @@ class SampleUserServiceTwo(services.UserService): The user provided is just an string, that is provided by actors. """ # We store the value at storage, but never get used, just an example - self.storage.saveData('user', username) + self.storage.save_to_db('user', username) def user_logged_out(self, username) -> None: """ diff --git a/server/src/uds/services/Test/deployment.py b/server/src/uds/services/Test/deployment.py index dfeb74a53..9b68bb8eb 100644 --- a/server/src/uds/services/Test/deployment.py +++ b/server/src/uds/services/Test/deployment.py @@ -99,7 +99,7 @@ class TestUserService(services.UserService): def get_ip(self) -> str: logger.info('Getting ip of deployment %s', self.data) - ip = typing.cast(str, self.storage.readData('ip')) + ip = typing.cast(str, self.storage.read_from_db('ip')) if ip is None: ip = '8.6.4.2' # Sample IP for testing purposses only return ip diff --git a/server/src/uds/urls.py b/server/src/uds/urls.py index df52daaf7..d1660ae51 100644 --- a/server/src/uds/urls.py +++ b/server/src/uds/urls.py @@ -183,7 +183,7 @@ urlpatterns = [ # Services list, ... path( r'uds/webapi/services', - uds.web.views.main.servicesData, + uds.web.views.main.services_data_json, name='webapi.services', ), # Transport own link processor diff --git a/server/src/uds/web/util/authentication.py b/server/src/uds/web/util/authentication.py index cdbecf6ff..27f4b4f9a 100644 --- a/server/src/uds/web/util/authentication.py +++ b/server/src/uds/web/util/authentication.py @@ -35,7 +35,7 @@ from django.http import HttpResponseRedirect from django.utils.translation import gettext as _ -from uds.core.auths.auth import authenticate, authenticate_log_login +from uds.core.auths.auth import authenticate, log_login from uds.models import Authenticator, User from uds.core.util.config import GlobalConfig from uds.core.util.cache import Cache @@ -114,13 +114,13 @@ def check_login( # pylint: disable=too-many-branches, too-many-statements and (tries >= maxTries) or triesByIp >= maxTries ): - authenticate_log_login(request, authenticator, userName, 'Temporarily blocked') + log_login(request, authenticator, userName, 'Temporarily blocked') return LoginResult( errstr=_('Too many authentication errrors. User temporarily blocked') ) # check if authenticator is visible for this requests if authInstance.is_ip_allowed(request=request) is False: - authenticate_log_login( + log_login( request, authenticator, userName, @@ -136,7 +136,7 @@ def check_login( # pylint: disable=too-many-branches, too-many-statements logger.debug("Invalid user %s (access denied)", userName) cache.put(cacheKey, tries + 1, GlobalConfig.LOGIN_BLOCK.getInt()) cache.put(request.ip, triesByIp + 1, GlobalConfig.LOGIN_BLOCK.getInt()) - authenticate_log_login( + log_login( request, authenticator, userName, @@ -154,7 +154,7 @@ def check_login( # pylint: disable=too-many-branches, too-many-statements if form.cleaned_data['logouturl'] != '': logger.debug('The logoout url will be %s', form.cleaned_data['logouturl']) request.session['logouturl'] = form.cleaned_data['logouturl'] - authenticate_log_login(request, authenticator, authResult.user.name) + log_login(request, authenticator, authResult.user.name) return LoginResult(user=authResult.user, password=form.cleaned_data['password']) logger.info('Invalid form received') diff --git a/server/src/uds/web/views/auth.py b/server/src/uds/web/views/auth.py index f9f662a86..a59a1055d 100644 --- a/server/src/uds/web/views/auth.py +++ b/server/src/uds/web/views/auth.py @@ -45,7 +45,7 @@ from uds.core.auths.auth import ( web_login, web_logout, authenticate_via_callback, - authenticate_log_login, + log_login, getUDSCookie, ) from uds.core.managers.user_service import UserServiceManager @@ -111,14 +111,14 @@ def auth_callback_stage2(request: 'ExtendedHttpRequestWithUser', ticketId: str) raise exceptions.auth.Redirect(result.url) if result.user is None: - authenticate_log_login(request, authenticator, f'{params}', 'Invalid at auth callback') + log_login(request, authenticator, f'{params}', 'Invalid at auth callback') raise exceptions.auth.InvalidUserException() response = HttpResponseRedirect(reverse('page.index')) web_login(request, response, result.user, '') # Password is unavailable in this case - authenticate_log_login(request, authenticator, result.user.name, 'Federated login') + log_login(request, authenticator, result.user.name, 'Federated login') # If MFA is provided, we need to redirect to MFA page request.authorized = True @@ -239,7 +239,7 @@ def ticket_auth( web_login(request, None, usr, password) # Log the login - authenticate_log_login(request, auth, username, 'Ticket authentication') + log_login(request, auth, username, 'Ticket authentication') request.user = ( usr # Temporarily store this user as "authenticated" user, next requests will be done using session @@ -256,7 +256,9 @@ def ticket_auth( # Check if servicePool is part of the ticket if poolUuid: # Request service, with transport = None so it is automatic - res = UserServiceManager().get_user_service_info(request.user, request.os, request.ip, poolUuid, None, False) + res = UserServiceManager().get_user_service_info( + request.user, request.os, request.ip, poolUuid, None, False + ) _, userService, _, transport, _ = res transportInstance = transport.get_instance() diff --git a/server/src/uds/web/views/main.py b/server/src/uds/web/views/main.py index 521561529..e410e045d 100644 --- a/server/src/uds/web/views/main.py +++ b/server/src/uds/web/views/main.py @@ -28,46 +28,39 @@ """ Author: Adolfo Gómez, dkmaster at dkmon dot com """ +import collections.abc +import datetime +import json +import logging +import random import re import time -import logging import typing -import collections.abc -import random -import json +from django.http import HttpRequest, HttpResponse, HttpResponseRedirect, JsonResponse from django.middleware import csrf from django.shortcuts import render -from django.views.decorators.csrf import csrf_exempt -from django.http import HttpRequest, HttpResponse, JsonResponse, HttpResponseRedirect -from django.views.decorators.cache import never_cache from django.urls import reverse from django.utils.translation import gettext as _ -from py import log -from uds.core.types.request import ExtendedHttpRequest +from django.views.decorators.cache import never_cache +from django.views.decorators.csrf import csrf_exempt -from uds.core.types.request import ExtendedHttpRequestWithUser +from uds import auths, models +from uds.core import exceptions, mfas, types, consts from uds.core.auths import auth -from uds.core.util import state -from uds.core.util.config import GlobalConfig from uds.core.managers.crypto import CryptoManager from uds.core.managers.user_service import UserServiceManager -from uds.web.util import errors +from uds.core.types.request import ExtendedHttpRequest, ExtendedHttpRequestWithUser +from uds.core.util import config, state, storage +from uds.core.util.model import sql_stamp_seconds from uds.web.forms.LoginForm import LoginForm from uds.web.forms.MFAForm import MFAForm +from uds.web.util import configjs, errors from uds.web.util.authentication import check_login from uds.web.util.services import getServicesData -from uds.web.util import configjs -from uds.core import mfas, types, exceptions -from uds import auths, models -from uds.core.util.model import sql_stamp_seconds - logger = logging.getLogger(__name__) -CSRF_FIELD = 'csrfmiddlewaretoken' -MFA_COOKIE_NAME = 'mfa_status' - if typing.TYPE_CHECKING: pass @@ -82,7 +75,7 @@ def index(request: HttpRequest) -> HttpResponse: response = render( request=request, template_name='uds/modern/index.html', - context={'csrf_field': CSRF_FIELD, 'csrf_token': csrf_token}, + context={'csrf_field': consts.auth.CSRF_FIELD, 'csrf_token': csrf_token}, ) # Ensure UDS cookie is present @@ -153,7 +146,7 @@ def login(request: ExtendedHttpRequest, tag: typing.Optional[str] = None) -> Htt @never_cache @auth.web_login_required(admin=False) def logout(request: ExtendedHttpRequestWithUser) -> HttpResponse: - auth.auth_log_logout(request) + auth.log_logout(request) request.session['restricted'] = False # Remove restricted request.authorized = False logoutResponse = request.user.logout(request) @@ -169,11 +162,11 @@ def js(request: ExtendedHttpRequest) -> HttpResponse: @never_cache @auth.deny_non_authenticated # web_login_required not used here because this is not a web page, but js -def servicesData(request: ExtendedHttpRequestWithUser) -> HttpResponse: +def services_data_json(request: ExtendedHttpRequestWithUser) -> HttpResponse: return JsonResponse(getServicesData(request)) -# The MFA page does not needs CRF token, so we disable it +# The MFA page does not needs CSRF token, so we disable it @csrf_exempt def mfa( request: ExtendedHttpRequest, @@ -182,26 +175,37 @@ def mfa( logger.warning('MFA: No user or user is already authorized') return HttpResponseRedirect(reverse('page.index')) # No user, no MFA - mfaProvider = typing.cast('None|models.MFA', request.user.manager.mfa) - if not mfaProvider: + store: 'storage.Storage' = storage.Storage('mfs') + + mfa_provider = typing.cast('None|models.MFA', request.user.manager.mfa) + if not mfa_provider: logger.warning('MFA: No MFA provider for user') return HttpResponseRedirect(reverse('page.index')) - mfaUserId = mfas.MFA.get_user_id(request.user) + mfa_user_id = mfas.MFA.get_user_id(request.user) # Try to get cookie anc check it - mfaCookie = request.COOKIES.get(MFA_COOKIE_NAME, None) - if mfaCookie == mfaUserId: # Cookie is valid, skip MFA setting authorization - logger.debug('MFA: Cookie is valid, skipping MFA') - request.authorized = True - return HttpResponseRedirect(reverse('page.index')) + mfa_cookie = request.COOKIES.get(consts.auth.MFA_COOKIE_NAME, None) + if mfa_cookie and mfa_provider.remember_device > 0: + stored_user_id: typing.Optional[str] + created: typing.Optional[datetime.datetime] + stored_user_id, created = store.get_unpickle(mfa_cookie) or (None, None) + if ( + stored_user_id + and created + and created + datetime.timedelta(hours=mfa_provider.remember_device) > datetime.datetime.now() + ): + # Cookie is valid, skip MFA setting authorization + logger.debug('MFA: Cookie is valid, skipping MFA') + request.authorized = True + return HttpResponseRedirect(reverse('page.index')) # Obtain MFA data - authInstance = request.user.manager.get_instance() - mfaInstance = typing.cast('mfas.MFA', mfaProvider.get_instance()) + auth_instance = request.user.manager.get_instance() + mfa_instance = typing.cast('mfas.MFA', mfa_provider.get_instance()) # Get validity duration - validity = mfaProvider.validity * 60 + validity = mfa_provider.validity * 60 now = sql_stamp_seconds() start_time = request.session.get('mfa_start_time', now) @@ -211,26 +215,26 @@ def mfa( request.session.flush() # Clear session, and redirect to login return HttpResponseRedirect(reverse('page.login')) - mfaIdentifier = authInstance.mfa_identifier(request.user.name) - label = mfaInstance.label() + mfa_identifier = auth_instance.mfa_identifier(request.user.name) + label = mfa_instance.label() - if not mfaIdentifier: - emtpyIdentifiedAllowed = mfaInstance.allow_login_without_identifier(request) + if not mfa_identifier: + allow_login_without_identifier = mfa_instance.allow_login_without_identifier(request) # can be True, False or None - if emtpyIdentifiedAllowed is True: + if allow_login_without_identifier is True: # Allow login request.authorized = True return HttpResponseRedirect(reverse('page.index')) - if emtpyIdentifiedAllowed is False: + if allow_login_without_identifier is False: # Not allowed to login, redirect to login error page logger.warning( 'MFA identifier not found for user %s on authenticator %s. It is required by MFA %s', request.user.name, request.user.manager.name, - mfaProvider.name, + mfa_provider.name, ) return errors.errorView(request, errors.ACCESS_DENIED) - # None, the authenticator will decide what to do if mfaIdentifier is empty + # None, the authenticator will decide what to do if mfa_identifier is empty tries = request.session.get('mfa_tries', 0) if request.method == 'POST': # User has provided MFA code @@ -238,11 +242,11 @@ def mfa( if form.is_valid(): code = form.cleaned_data['code'] try: - mfaInstance.validate( + mfa_instance.validate( request, - mfaUserId, + mfa_user_id, request.user.name, - mfaIdentifier, + mfa_identifier, code, validity=validity, ) # Will raise MFAError if code is not valid @@ -256,11 +260,17 @@ def mfa( response = HttpResponseRedirect(reverse('page.index')) # If mfaProvider requests to keep MFA code on client, create a mfacookie for this user - if mfaProvider.remember_device > 0 and form.cleaned_data['remember'] is True: + if mfa_provider.remember_device > 0 and form.cleaned_data['remember'] is True: + # Store also cookie locally, to check if remember_device is changed + mfa_cookie = CryptoManager().random_string(96) + store.put_pickle( + mfa_cookie, + (mfa_user_id, now), + ) response.set_cookie( - MFA_COOKIE_NAME, - mfaUserId, - max_age=mfaProvider.remember_device * 60 * 60, + consts.auth.MFA_COOKIE_NAME, + mfa_cookie, + max_age=mfa_provider.remember_device * 60 * 60, ) return response @@ -268,7 +278,7 @@ def mfa( logger.error('MFA error: %s', e) tries += 1 request.session['mfa_tries'] = tries - if tries >= GlobalConfig.MAX_LOGIN_TRIES.getInt(): + if tries >= config.GlobalConfig.MAX_LOGIN_TRIES.getInt(): # Clean session request.session.flush() # Too many tries, redirect to login error page @@ -280,11 +290,11 @@ def mfa( # Make MFA send a code request.session['mfa_tries'] = 0 # Reset tries try: - result = mfaInstance.process( + result = mfa_instance.process( request, - mfaUserId, + mfa_user_id, request.user.name, - mfaIdentifier, + mfa_identifier, validity=validity, ) if result == mfas.MFA.RESULT.ALLOWED: @@ -302,15 +312,15 @@ def mfa( # Compose a nice "XX years, XX months, XX days, XX hours, XX minutes" string from mfaProvider.remember_device remember_device = '' # Remember_device is in hours - if mfaProvider.remember_device > 0: + if mfa_provider.remember_device > 0: # if more than a day, we show days only - if mfaProvider.remember_device >= 24: - remember_device = _('{} days').format(mfaProvider.remember_device // 24) + if mfa_provider.remember_device >= 24: + remember_device = _('{} days').format(mfa_provider.remember_device // 24) else: - remember_device = _('{} hours').format(mfaProvider.remember_device) + remember_device = _('{} hours').format(mfa_provider.remember_device) # Html from MFA provider - mfaHtml = mfaInstance.html(request, mfaUserId, request.user.name) + mfaHtml = mfa_instance.html(request, mfa_user_id, request.user.name) # Redirect to index, but with MFA data request.session['mfa'] = { diff --git a/server/tests/core/util/test_storage.py b/server/tests/core/util/test_storage.py index a2f4c41af..4eb37f265 100644 --- a/server/tests/core/util/test_storage.py +++ b/server/tests/core/util/test_storage.py @@ -43,24 +43,24 @@ class StorageTest(UDSTestCase): storage = Storage(UNICODE_CHARS) storage.put(UNICODE_CHARS, b'chars') - storage.saveData('saveData', UNICODE_CHARS, UNICODE_CHARS) - storage.saveData('saveData2', UNICODE_CHARS_2, UNICODE_CHARS) - storage.saveData('saveData3', UNICODE_CHARS, 'attribute') - storage.saveData('saveData4', UNICODE_CHARS_2, 'attribute') + storage.save_to_db('saveData', UNICODE_CHARS, UNICODE_CHARS) + storage.save_to_db('saveData2', UNICODE_CHARS_2, UNICODE_CHARS) + storage.save_to_db('saveData3', UNICODE_CHARS, 'attribute') + storage.save_to_db('saveData4', UNICODE_CHARS_2, 'attribute') storage.put(b'key', UNICODE_CHARS) storage.put(UNICODE_CHARS_2, UNICODE_CHARS) storage.put_pickle('pickle', VALUE_1) self.assertEqual(storage.get(UNICODE_CHARS), u'chars') # Always returns unicod - self.assertEqual(storage.readData('saveData'), UNICODE_CHARS) - self.assertEqual(storage.readData('saveData2'), UNICODE_CHARS_2) + self.assertEqual(storage.read_from_db('saveData'), UNICODE_CHARS) + self.assertEqual(storage.read_from_db('saveData2'), UNICODE_CHARS_2) self.assertEqual(storage.get(b'key'), UNICODE_CHARS) self.assertEqual(storage.get(UNICODE_CHARS_2), UNICODE_CHARS) - self.assertEqual(storage.getPickle('pickle'), VALUE_1) + self.assertEqual(storage.get_unpickle('pickle'), VALUE_1) - self.assertEqual(len(list(storage.locateByAttr1(UNICODE_CHARS))), 2) - self.assertEqual(len(list(storage.locateByAttr1('attribute'))), 2) + self.assertEqual(len(list(storage.search_by_attr1(UNICODE_CHARS))), 2) + self.assertEqual(len(list(storage.search_by_attr1('attribute'))), 2) storage.remove(UNICODE_CHARS) storage.remove(b'key') @@ -68,4 +68,4 @@ class StorageTest(UDSTestCase): self.assertIsNone(storage.get(UNICODE_CHARS)) self.assertIsNone(storage.get(b'key')) - self.assertIsNone(storage.getPickle('pickle')) + self.assertIsNone(storage.get_unpickle('pickle'))