mirror of https://github.com/dkmstr/openuds.git synced 2024-12-25 23:21:41 +03:00

Improved MFA code storage in the browser and related security

This commit is contained in:
Adolfo Gómez García 2024-01-10 04:35:13 +01:00
parent 1715e1a7a1
commit bf3d36c901
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
29 changed files with 284 additions and 242 deletions
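
The security change in a nutshell: the mfa_status cookie no longer carries the MFA user id itself. It now holds a random token that maps, server side, to the user id and the issue time, and it is only honoured while the provider's remember_device window is still open. A minimal standalone sketch of that scheme, with a plain dict standing in for the Storage('mfs') store used in the web-views hunks near the end of this diff:

import datetime
import secrets
import typing

REMEMBER_DEVICE_HOURS = 24  # illustrative; in the diff this comes from mfa_provider.remember_device
server_side_store: dict[str, tuple[str, datetime.datetime]] = {}  # stands in for storage.Storage('mfs')

def issue_mfa_cookie(user_id: str) -> str:
    # The browser only ever receives a random token; the user id and issue time stay server side.
    token = secrets.token_urlsafe(72)
    server_side_store[token] = (user_id, datetime.datetime.now())
    return token

def user_for_mfa_cookie(token: str) -> typing.Optional[str]:
    # Honour the token only while the remember-device window is still open.
    stored_user_id, created = server_side_store.get(token, (None, None))
    if stored_user_id and created and created + datetime.timedelta(hours=REMEMBER_DEVICE_HOURS) > datetime.datetime.now():
        return stored_user_id
    return None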

View File

@ -41,7 +41,7 @@ import dns.reversename
from django.utils.translation import gettext_noop as _
from uds.core import auths, types, exceptions, consts
from uds.core.auths.auth import authenticate_log_login
from uds.core.auths.auth import log_login
from uds.core.managers.crypto import CryptoManager
from uds.core.ui import gui
from uds.core.util.state import State
@ -154,7 +154,7 @@ class InternalDBAuth(auths.Authenticator):
try:
user: 'models.User' = dbAuth.users.get(name=username, state=State.ACTIVE)
except Exception:
authenticate_log_login(request, self.db_obj(), username, 'Invalid user')
log_login(request, self.db_obj(), username, 'Invalid user')
return types.auth.FAILED_AUTH
if user.parent: # Direct auth not allowed for "derived" users
@ -165,7 +165,7 @@ class InternalDBAuth(auths.Authenticator):
groupsManager.validate([g.name for g in user.groups.all()])
return types.auth.SUCCESS_AUTH
authenticate_log_login(request, self.db_obj(), username, 'Invalid password')
log_login(request, self.db_obj(), username, 'Invalid password')
return types.auth.FAILED_AUTH
def get_groups(self, username: str, groupsManager: 'auths.GroupsManager'):

View File

@ -226,7 +226,7 @@ class OAuth2Authenticator(auths.Authenticator):
tab=_('Attributes'),
)
def _getPublicKeys(self) -> list[typing.Any]: # In fact, any of the PublicKey types
def _get_public_keys(self) -> list[typing.Any]: # In fact, any of the PublicKey types
# Get certificates in self.publicKey.value, encoded as PEM
# Return a list of certificates in DER format
if self.publicKey.value.strip() == '':
@ -234,7 +234,7 @@ class OAuth2Authenticator(auths.Authenticator):
return [cert.public_key() for cert in fields.get_vertificates_from_field(self.publicKey)]
def _codeVerifierAndChallenge(self) -> tuple[str, str]:
def _code_verifier_and_challenge(self) -> tuple[str, str]:
"""Generate a code verifier and a code challenge for PKCE
Returns:
@ -249,7 +249,7 @@ class OAuth2Authenticator(auths.Authenticator):
return codeVerifier, codeChallenge
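
For context, a verifier/challenge pair of the kind this helper returns is normally derived with the S256 method (matching the code_challenge_method='S256' set further down in this diff). A generic sketch, not the elided body of the method:

import base64
import hashlib
import secrets

def pkce_pair() -> tuple[str, str]:
    # code_verifier: high-entropy URL-safe string; code_challenge: BASE64URL(SHA256(verifier)), unpadded.
    code_verifier = secrets.token_urlsafe(64)
    digest = hashlib.sha256(code_verifier.encode('ascii')).digest()
    code_challenge = base64.urlsafe_b64encode(digest).decode('ascii').rstrip('=')
    return code_verifier, code_challenge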
def _getResponseTypeString(self) -> str:
def _get_response_type_string(self) -> str:
match self.responseType.value:
case 'code':
return 'code'
@ -264,14 +264,14 @@ class OAuth2Authenticator(auths.Authenticator):
case _:
raise Exception('Invalid response type')
def _getLoginURL(self, request: 'HttpRequest') -> str:
def _get_login_url(self, request: 'HttpRequest') -> str:
"""
:type request: django.http.request.HttpRequest
"""
state: str = secrets.token_urlsafe(STATE_LENGTH)
param_dict = {
'response_type': self._getResponseTypeString(),
'response_type': self._get_response_type_string(),
'client_id': self.clientId.value,
'redirect_uri': self.redirectionEndpoint.value,
'scope': self.scope.value.replace(',', ' '),
@ -296,7 +296,7 @@ class OAuth2Authenticator(auths.Authenticator):
case 'pkce':
# PKCE flow
codeVerifier, codeChallenge = self._codeVerifierAndChallenge()
codeVerifier, codeChallenge = self._code_verifier_and_challenge()
param_dict['code_challenge'] = codeChallenge
param_dict['code_challenge_method'] = 'S256'
self.cache.put(state, codeVerifier, 3600)
@ -315,7 +315,7 @@ class OAuth2Authenticator(auths.Authenticator):
return self.authorizationEndpoint.value + '?' + params
def _requestToken(self, code: str, code_verifier: typing.Optional[str] = None) -> TokenInfo:
def _request_token(self, code: str, code_verifier: typing.Optional[str] = None) -> TokenInfo:
"""Request a token from the token endpoint using the code received from the authorization endpoint
Args:
@ -342,7 +342,7 @@ class OAuth2Authenticator(auths.Authenticator):
return TokenInfo.from_dict(req.json())
def _requestInfo(self, token: 'TokenInfo') -> dict[str, typing.Any]:
def _request_info(self, token: 'TokenInfo') -> dict[str, typing.Any]:
"""Request user info from the info endpoint using the token received from the token endpoint
If the token endpoint returns the user info, this method will not be used
@ -374,7 +374,7 @@ class OAuth2Authenticator(auths.Authenticator):
userInfo = req.json()
return userInfo
def _processToken(
def _process_token(
self, userInfo: collections.abc.Mapping[str, typing.Any], gm: 'auths.GroupsManager'
) -> types.auth.AuthenticationResult:
# After this point, we don't mind about the token, we only need to authenticate user
@ -401,7 +401,7 @@ class OAuth2Authenticator(auths.Authenticator):
# and if we are here, the user is authenticated, so we can return SUCCESS_AUTH
return types.auth.AuthenticationResult(types.auth.AuthenticationState.SUCCESS, username=username)
def _processTokenOpenId(
def _process_token_open_id(
self, token_id: str, nonce: str, gm: 'auths.GroupsManager'
) -> types.auth.AuthenticationResult:
# Get token headers, to extract algorithm
@ -410,7 +410,7 @@ class OAuth2Authenticator(auths.Authenticator):
# We may have multiple public keys, try them all
# (We should only have one, but just in case)
for key in self._getPublicKeys():
for key in self._get_public_keys():
logger.debug('Key = %s', key)
try:
payload = jwt.decode(token, key=key, audience=self.clientId.value, algorithms=[info.get('alg', 'RSA256')]) # type: ignore
@ -423,7 +423,7 @@ class OAuth2Authenticator(auths.Authenticator):
# All is fine, get user & look for groups
# Process attributes from payload
return self._processToken(payload, gm)
return self._process_token(payload, gm)
except (jwt.InvalidTokenError, IndexError):
# logger.debug('Data was invalid: %s', e)
pass
@ -477,13 +477,13 @@ class OAuth2Authenticator(auths.Authenticator):
) -> types.auth.AuthenticationResult:
match self.responseType.value:
case 'code' | 'pkce':
return self.authCallbackCode(parameters, gm, request)
return self.auth_callback_code(parameters, gm, request)
case 'token':
return self.authCallbackToken(parameters, gm, request)
return self.auth_callback_token(parameters, gm, request)
case 'openid+code':
return self.authCallbackOpenIdCode(parameters, gm, request)
return self.auth_callback_openid_code(parameters, gm, request)
case 'openid+token_id':
return self.authCallbackOpenIdIdToken(parameters, gm, request)
return self.authcallback_openid_id_token(parameters, gm, request)
case _:
raise Exception('Invalid response type')
return auths.SUCCESS_AUTH
@ -499,21 +499,21 @@ class OAuth2Authenticator(auths.Authenticator):
"""
We will here compose the azure request and send it via http-redirect
"""
return f'window.location="{self._getLoginURL(request)}";'
return f'window.location="{self._get_login_url(request)}";'
def get_groups(self, username: str, groupsManager: 'auths.GroupsManager'):
data = self.storage.getPickle(username)
data = self.storage.get_unpickle(username)
if not data:
return
groupsManager.validate(data[1])
def get_real_name(self, username: str) -> str:
data = self.storage.getPickle(username)
data = self.storage.get_unpickle(username)
if not data:
return username
return data[0]
def authCallbackCode(
def auth_callback_code(
self,
parameters: 'types.auth.AuthCallbackParams',
gm: 'auths.GroupsManager',
@ -539,10 +539,10 @@ class OAuth2Authenticator(auths.Authenticator):
if code_verifier == 'none':
code_verifier = None
token = self._requestToken(code, code_verifier)
return self._processToken(self._requestInfo(token), gm)
token = self._request_token(code, code_verifier)
return self._process_token(self._request_info(token), gm)
def authCallbackToken(
def auth_callback_token(
self,
parameters: 'types.auth.AuthCallbackParams',
gm: 'auths.GroupsManager',
@ -568,9 +568,9 @@ class OAuth2Authenticator(auths.Authenticator):
info={},
id_token=None,
)
return self._processToken(self._requestInfo(token), gm)
return self._process_token(self._request_info(token), gm)
def authCallbackOpenIdCode(
def auth_callback_openid_code(
self,
parameters: 'types.auth.AuthCallbackParams',
gm: 'auths.GroupsManager',
@ -592,15 +592,15 @@ class OAuth2Authenticator(auths.Authenticator):
return types.auth.FAILED_AUTH
# Get the token, token_type, expires
token = self._requestToken(code)
token = self._request_token(code)
if not token.id_token:
logger.error('No id_token received on OAuth2 callback')
return types.auth.FAILED_AUTH
return self._processTokenOpenId(token.id_token, nonce, gm)
return self._process_token_open_id(token.id_token, nonce, gm)
def authCallbackOpenIdIdToken(
def authcallback_openid_id_token(
self,
parameters: 'types.auth.AuthCallbackParams',
gm: 'auths.GroupsManager',
@ -621,4 +621,4 @@ class OAuth2Authenticator(auths.Authenticator):
logger.error('Invalid id_token received on OAuth2 callback')
return types.auth.FAILED_AUTH
return self._processTokenOpenId(id_token, nonce, gm)
return self._process_token_open_id(id_token, nonce, gm)
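
For orientation, the renamed helpers chain together like this in the plain authorization-code flow (condensed from auth_callback_code above; extraction of the code from the callback parameters is elided in that hunk, so it stays a placeholder here):

# Inside auth_callback_code(), roughly:
code = ...                                        # taken from the callback parameters (elided above)
token = self._request_token(code, code_verifier)  # exchange the code at the token endpoint
user_info = self._request_info(token)             # fetch the claims from the info endpoint
return self._process_token(user_info, gm)         # map attributes to username and groups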

View File

@ -37,7 +37,7 @@ import collections.abc
from django.utils.translation import gettext_noop as _
from uds.core import auths, types, consts
from uds.core.auths.auth import authenticate_log_login
from uds.core.auths.auth import log_login
from uds.core.managers.crypto import CryptoManager
from uds.core.ui import gui
@ -125,7 +125,7 @@ class RadiusAuth(auths.Authenticator):
def initialize(self, values: typing.Optional[dict[str, typing.Any]]) -> None:
pass
def radiusClient(self) -> client.RadiusClient:
def radius_client(self) -> client.RadiusClient:
"""Return a new radius client ."""
return client.RadiusClient(
self.server.value,
@ -135,11 +135,11 @@ class RadiusAuth(auths.Authenticator):
appClassPrefix=self.appClassPrefix.value,
)
def mfaStorageKey(self, username: str) -> str:
def mfa_storage_key(self, username: str) -> str:
return 'mfa_' + str(self.db_obj().uuid) + username
def mfa_identifier(self, username: str) -> str:
return self.storage.getPickle(self.mfaStorageKey(username)) or ''
return self.storage.get_unpickle(self.mfa_storage_key(username)) or ''
def authenticate(
self,
@ -149,7 +149,7 @@ class RadiusAuth(auths.Authenticator):
request: 'ExtendedHttpRequest',
) -> types.auth.AuthenticationResult:
try:
connection = self.radiusClient()
connection = self.radius_client()
groups, mfaCode, state = connection.authenticate(username=username, password=credentials, mfaField=self.mfaAttr.value.strip())
# If state, store in session
if state:
@ -157,12 +157,12 @@ class RadiusAuth(auths.Authenticator):
# store the user mfa attribute if it is set
if mfaCode:
self.storage.put_pickle(
self.mfaStorageKey(username),
self.mfa_storage_key(username),
mfaCode,
)
except Exception:
authenticate_log_login(
log_login(
request,
self.db_obj(),
username,
@ -200,15 +200,15 @@ class RadiusAuth(auths.Authenticator):
"""Test the connection to the server ."""
try:
auth = RadiusAuth(None, env, data) # type: ignore
return auth.testConnection()
return auth.test_connection()
except Exception as e:
logger.error("Exception found testing Radius auth %s: %s", e.__class__, e)
return [False, _('Error testing connection')]
def testConnection(self):
def test_connection(self):
"""Test connection to Radius Server"""
try:
connection = self.radiusClient()
connection = self.radius_client()
# Reply is not important...
connection.authenticate(
CryptoManager().random_string(10), CryptoManager().random_string(10)

View File

@ -40,7 +40,7 @@ import ldap
from django.utils.translation import gettext_noop as _
from uds.core import auths, exceptions, types, consts
from uds.core.auths.auth import authenticate_log_login
from uds.core.auths.auth import log_login
from uds.core.ui import gui
from uds.core.util import ldaputil, auth as auth_utils
@ -275,7 +275,7 @@ class RegexLdap(auths.Authenticator):
return 'mfa_' + self.db_obj().uuid + username
def mfa_identifier(self, username: str) -> str:
return self.storage.getPickle(self.mfaStorageKey(username)) or ''
return self.storage.get_unpickle(self.mfaStorageKey(username)) or ''
def dict_of_values(self) -> gui.ValuesDictType:
return {
@ -494,14 +494,14 @@ class RegexLdap(auths.Authenticator):
usr = self.__getUser(username)
if usr is None:
authenticate_log_login(request, self.db_obj(), username, 'Invalid user')
log_login(request, self.db_obj(), username, 'Invalid user')
return types.auth.FAILED_AUTH
try:
# Let's see first if it credentials are fine
self.__connectAs(usr['dn'], credentials) # Will raise an exception if it can't connect
except Exception:
authenticate_log_login(request, self.db_obj(), username, 'Invalid password')
log_login(request, self.db_obj(), username, 'Invalid password')
return types.auth.FAILED_AUTH
# store the user mfa attribute if it is set

View File

@ -567,7 +567,7 @@ class SAMLAuthenticator(auths.Authenticator):
self.storage.remove(self.mfaStorageKey(username))
def mfa_identifier(self, username: str) -> str:
return self.storage.getPickle(self.mfaStorageKey(username)) or ''
return self.storage.get_unpickle(self.mfaStorageKey(username)) or ''
def logoutFromCallback(
self,
@ -725,13 +725,13 @@ class SAMLAuthenticator(auths.Authenticator):
)
def get_groups(self, username: str, groupsManager: 'auths.GroupsManager'):
data = self.storage.getPickle(username)
data = self.storage.get_unpickle(username)
if not data:
return
groupsManager.validate(data[1])
def get_real_name(self, username: str) -> str:
data = self.storage.getPickle(username)
data = self.storage.get_unpickle(username)
if not data:
return username
return data[0]

View File

@ -38,7 +38,7 @@ import ldap.filter
from django.utils.translation import gettext_noop as _
from uds.core import auths, types, consts, exceptions
from uds.core.auths.auth import authenticate_log_login
from uds.core.auths.auth import log_login
from uds.core.ui import gui
from uds.core.util import ldaputil
@ -316,7 +316,7 @@ class SimpleLDAPAuthenticator(auths.Authenticator):
return 'mfa_' + str(self.db_obj().uuid) + username
def mfa_identifier(self, username: str) -> str:
return self.storage.getPickle(self.mfaStorageKey(username)) or ''
return self.storage.get_unpickle(self.mfaStorageKey(username)) or ''
def __connection(self):
"""
@ -447,14 +447,14 @@ class SimpleLDAPAuthenticator(auths.Authenticator):
user = self.__getUser(username)
if user is None:
authenticate_log_login(request, self.db_obj(), username, 'Invalid user')
log_login(request, self.db_obj(), username, 'Invalid user')
return types.auth.FAILED_AUTH
try:
# Let's see first if it credentials are fine
self.__connectAs(user['dn'], credentials) # Will raise an exception if it can't connect
except Exception:
authenticate_log_login(request, self.db_obj(), username, 'Invalid password')
log_login(request, self.db_obj(), username, 'Invalid password')
return types.auth.FAILED_AUTH
# store the user mfa attribute if it is set

View File

@ -513,7 +513,7 @@ def web_logout(
return response
def authenticate_log_login(
def log_login(
request: 'ExtendedHttpRequest',
authenticator: models.Authenticator,
userName: str,
@ -558,7 +558,7 @@ def authenticate_log_login(
logger.info('Root %s from %s where OS is %s', log_string, request.ip, request.os.os.name)
def auth_log_logout(request: 'ExtendedHttpRequest') -> None:
def log_logout(request: 'ExtendedHttpRequest') -> None:
if request.user:
if request.user.manager.id is not None:
log.log(

View File

@ -39,8 +39,12 @@ HIDDEN: typing.Final[str] = 'h'
DISABLED: typing.Final[str] = 'd'
# net_filter
# Note: this are STANDARD values used on "default field" networks on RESP API
# Named them for better reading, but cannot be changed, since they are used on RESP API
# Note: these are STANDARD values used on "default field" networks on REST API
# Named them for better reading, but cannot be changed, since they are used on REST API
NO_FILTERING: typing.Final[str] = 'n'
ALLOW: typing.Final[str] = 'a'
DENY: typing.Final[str] = 'd'
# Cookie for mfa and csrf field
MFA_COOKIE_NAME: typing.Final[str] = 'mfa_status'
CSRF_FIELD: typing.Final[str] = 'csrfmiddlewaretoken'

View File

@ -315,3 +315,9 @@ class CryptoManager(metaclass=singleton.Singleton):
return hashlib.sha3_256(
(self.random_string(24, True) + datetime.datetime.now().strftime('%H%M%S%f')).encode()
).hexdigest()
def sha(self, value: typing.Union[str, bytes]) -> str:
if isinstance(value, str):
value = value.encode()
return hashlib.sha3_256(value).hexdigest()
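
The new sha() helper returns a stable SHA3-256 hex digest for either str or bytes input; a quick usage sketch (values are illustrative):

cm = CryptoManager()
digest = cm.sha('some secret value')
assert digest == cm.sha(b'some secret value')   # str input is encoded to bytes first
assert len(digest) == 64                        # sha3_256 hexdigest length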

View File

@ -215,7 +215,7 @@ class MFA(Module):
Internal method to get the data from storage
"""
storageKey = request.ip + userId
return self.storage.getPickle(storageKey)
return self.storage.get_unpickle(storageKey)
def _remove_data(self, request: 'ExtendedHttpRequest', userId: str) -> None:
"""

View File

@ -443,7 +443,7 @@ class Service(Module):
def recover_id_info(self, id: str, delete: bool = False) -> typing.Any:
# recovers the information
value = self.storage.getPickle('__nfo_' + id)
value = self.storage.get_unpickle('__nfo_' + id)
if value and delete:
self.storage.delete('__nfo_' + id)
return value

View File

@ -65,28 +65,24 @@ class Cache:
return serializer.deserialize(codecs.decode(value.encode(), 'base64'))
_serializer: typing.ClassVar[collections.abc.Callable[[typing.Any], str]] = _basic_serialize
_deserializer: typing.ClassVar[
collections.abc.Callable[[str], typing.Any]
] = _basic_deserialize
_deserializer: typing.ClassVar[collections.abc.Callable[[str], typing.Any]] = _basic_deserialize
def __init__(self, owner: typing.Union[str, bytes]):
self._owner = typing.cast(str, owner.decode('utf-8') if isinstance(owner, bytes) else owner)
self._bowner = self._owner.encode('utf8')
def __getKey(self, key: typing.Union[str, bytes]) -> str:
def __get_key(self, key: typing.Union[str, bytes]) -> str:
if isinstance(key, str):
key = key.encode('utf8')
return hash_key(self._bowner + key)
def get(
self, skey: typing.Union[str, bytes], defValue: typing.Any = None
) -> typing.Any:
def get(self, skey: typing.Union[str, bytes], defValue: typing.Any = None) -> typing.Any:
now = sql_datetime()
# logger.debug('Requesting key "%s" for cache "%s"', skey, self._owner)
try:
key = self.__getKey(skey)
key = self.__get_key(skey)
# logger.debug('Key: %s', key)
c: DBCache = DBCache.objects.get(pk=key) # @UndefinedVariable
c: DBCache = DBCache.objects.get(owner=self._owner, pk=key)
# If expired
if now > c.created + datetime.timedelta(seconds=c.validity):
return defValue
@ -105,21 +101,22 @@ class Cache:
Cache.misses += 1
# logger.debug('key not found: %s', skey)
return defValue
#except OperationalError:
# If database is not ready, just return default value
# This is not a big issue, since cache is not critical
# and probably will be generated by sqlite on high concurrency
# except OperationalError:
# If database is not ready, just return default value
# This is not a big issue, since cache is not critical
# and probably will be generated by sqlite on high concurrency
# Cache.misses += 1
# return defValue
except Exception:
import inspect
# Get caller
error = 'Error Getting cache key from: '
for caller in inspect.stack():
error += f'{caller.filename}:{caller.lineno} -> '
logger.error(error)
# logger.exception('Error getting cache key: %s', skey)
Cache.misses += 1
return defValue
@ -137,7 +134,7 @@ class Cache:
"""
# logger.debug('Removing key "%s" for uService "%s"' % (skey, self._owner))
try:
key = self.__getKey(skey)
key = self.__get_key(skey)
DBCache.objects.get(pk=key).delete() # @UndefinedVariable
return True
except DBCache.DoesNotExist: # @UndefinedVariable
@ -163,22 +160,34 @@ class Cache:
# logger.debug('Saving key "%s" for cache "%s"' % (skey, self._owner,))
if validity is None:
validity = consts.system.DEFAULT_CACHE_TIMEOUT
key = self.__getKey(skey)
key = self.__get_key(skey)
strValue = Cache._serializer(value)
now = sql_datetime()
# Remove existing if any and create a new one
with transaction.atomic():
try:
# Remove if existing
DBCache.objects.filter(pk=key).delete()
# And create a new one
DBCache.objects.create(
DBCache.objects.update_or_create(
pk=key,
owner=self._owner,
key=key,
value=strValue,
created=now,
validity=validity,
) # @UndefinedVariable
defaults={
'owner': self._owner,
'key': key,
'value': strValue,
'created': now,
'validity': validity,
},
)
# # Remove if existing
# DBCache.objects.filter(pk=key).delete()
# # And create a new one
# DBCache.objects.create(
# owner=self._owner,
# key=key,
# value=strValue,
# created=now,
# validity=validity,
# ) # @UndefinedVariable
return # And return
except Exception as e:
logger.debug('Transaction in course, cannot store value: %s', e)
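
put() now issues a single update_or_create() instead of the previous delete-then-create pair. For reference, Django's update_or_create() looks the row up by the keyword filters (here the primary key), writes 'defaults' whether the row existed or not, and returns the object plus a created flag, which the hunk above simply discards; a minimal shape of the call:

obj, created = DBCache.objects.update_or_create(
    pk=key,  # lookup: the hashed cache key
    defaults={'owner': self._owner, 'key': key, 'value': strValue, 'created': now, 'validity': validity},
)
# 'created' is True when a new row was inserted, False when an existing one was updated in place.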
@ -192,7 +201,7 @@ class Cache:
def refresh(self, skey: typing.Union[str, bytes]) -> None:
# logger.debug('Refreshing key "%s" for cache "%s"' % (skey, self._owner,))
try:
key = self.__getKey(skey)
key = self.__get_key(skey)
c = DBCache.objects.get(pk=key)
c.created = sql_datetime()
c.save()

View File

@ -46,7 +46,7 @@ logger = logging.getLogger(__name__)
MARK = '_mgb_'
def _calcKey(owner: bytes, key: bytes, extra: typing.Optional[bytes] = None) -> str:
def _calculate_key(owner: bytes, key: bytes, extra: typing.Optional[bytes] = None) -> str:
h = hashlib.md5(usedforsecurity=False)
h.update(owner)
h.update(key)
@ -55,14 +55,14 @@ def _calcKey(owner: bytes, key: bytes, extra: typing.Optional[bytes] = None) ->
return h.hexdigest()
def _encodeValue(key: str, value: typing.Any, compat: bool = False) -> str:
def _encode_value(key: str, value: typing.Any, compat: bool = False) -> str:
if not compat:
return base64.b64encode(pickle.dumps((MARK, key, value))).decode()
# Compatibility save
return base64.b64encode(pickle.dumps(value)).decode()
def _decodeValue(dbk: str, value: typing.Optional[str]) -> tuple[str, typing.Any]:
def _decode_value(dbk: str, value: typing.Optional[str]) -> tuple[str, typing.Any]:
if value:
try:
v = pickle.loads(base64.b64decode(value.encode())) # nosec: This is e controled pickle loading
@ -121,7 +121,7 @@ class StorageAsDict(MutableMapping):
if key[0] == '#':
# Compat with old db key
return key[1:]
return _calcKey(self._owner.encode(), key.encode())
return _calculate_key(self._owner.encode(), key.encode())
def __getitem__(self, key: str) -> typing.Any:
if not isinstance(key, str):
@ -130,7 +130,11 @@ class StorageAsDict(MutableMapping):
dbk = self._key(key)
try:
c: DBStorage = typing.cast(DBStorage, self._db.get(pk=dbk))
return _decodeValue(dbk, c.data)[1] # Ignores original key
if c.owner != self._owner: # Maybe a key collision,
logger.error('Key collision detected for key %s', key)
return None
okey, value = _decode_value(dbk, c.data)
return _decode_value(dbk, c.data)[1] # Ignores original key
except DBStorage.DoesNotExist:
return None
@ -139,7 +143,7 @@ class StorageAsDict(MutableMapping):
raise TypeError(f'Key must be str type, {type(key)} found')
dbk = self._key(key)
data = _encodeValue(key, value, self._compat)
data = _encode_value(key, value, self._compat)
# ignores return value, we don't care if it was created or updated
DBStorage.objects.update_or_create(
key=dbk, defaults={'data': data, 'attr1': self._group, 'owner': self._owner}
@ -153,7 +157,7 @@ class StorageAsDict(MutableMapping):
"""
Iterates through keys
"""
return iter(_decodeValue(i.key, i.data)[0] for i in self._filtered)
return iter(_decode_value(i.key, i.data)[0] for i in self._filtered)
def __contains__(self, key: object) -> bool:
if isinstance(key, str):
@ -165,10 +169,10 @@ class StorageAsDict(MutableMapping):
# Optimized methods, avoid re-reading from DB
def items(self):
return iter(_decodeValue(i.key, i.data) for i in self._filtered)
return iter(_decode_value(i.key, i.data) for i in self._filtered)
def values(self):
return iter(_decodeValue(i.key, i.data)[1] for i in self._filtered)
return iter(_decode_value(i.key, i.data)[1] for i in self._filtered)
def get(self, key: str, default: typing.Any = None) -> typing.Any:
return self[key] or default
@ -191,6 +195,11 @@ class StorageAccess:
Allows the access to the storage as a dict, with atomic transaction if requested
"""
owner: str
group: typing.Optional[str]
atomic: bool
compat: bool
def __init__(
self,
owner: str,
@ -223,13 +232,13 @@ class Storage:
_bownwer: bytes
def __init__(self, owner: typing.Union[str, bytes]):
self._owner = owner.decode('utf-8') if isinstance(owner, bytes) else owner
self._owner = typing.cast(str, owner.decode('utf-8') if isinstance(owner, bytes) else owner)
self._bowner = self._owner.encode('utf8')
def getKey(self, key: typing.Union[str, bytes]) -> str:
return _calcKey(self._bowner, key.encode('utf8') if isinstance(key, str) else key)
def get_key(self, key: typing.Union[str, bytes]) -> str:
return _calculate_key(self._bowner, key.encode('utf8') if isinstance(key, str) else key)
def saveData(
def save_to_db(
self,
skey: typing.Union[str, bytes],
data: typing.Any,
@ -240,21 +249,21 @@ class Storage:
self.remove(skey)
return
key = self.getKey(skey)
key = self.get_key(skey)
if isinstance(data, str):
data = data.encode('utf-8')
dataStr = codecs.encode(data, 'base64').decode()
data_string = codecs.encode(data, 'base64').decode()
attr1 = attr1 or ''
try:
DBStorage.objects.create(owner=self._owner, key=key, data=dataStr, attr1=attr1)
DBStorage.objects.create(owner=self._owner, key=key, data=data_string, attr1=attr1)
except Exception:
with transaction.atomic():
DBStorage.objects.filter(key=key).select_for_update().update(
owner=self._owner, data=dataStr, attr1=attr1
owner=self._owner, data=data_string, attr1=attr1
) # @UndefinedVariable
def put(self, skey: typing.Union[str, bytes], data: typing.Any) -> None:
return self.saveData(skey, data)
return self.save_to_db(skey, data)
def put_pickle(
self,
@ -262,23 +271,23 @@ class Storage:
data: typing.Any,
attr1: typing.Optional[str] = None,
) -> None:
return self.saveData(
return self.save_to_db(
skey, pickle.dumps(data), attr1
) # Protocol 2 is compatible with python 2.7. This will be unnecesary when fully migrated
def updateData(
def update_to_db(
self,
skey: typing.Union[str, bytes],
data: typing.Any,
attr1: typing.Optional[str] = None,
) -> None:
self.saveData(skey, data, attr1)
self.save_to_db(skey, data, attr1)
def readData(
def read_from_db(
self, skey: typing.Union[str, bytes], fromPickle: bool = False
) -> typing.Optional[typing.Union[str, bytes]]:
try:
key = self.getKey(skey)
key = self.get_key(skey)
c: DBStorage = DBStorage.objects.get(pk=key) # @UndefinedVariable
val = codecs.decode(c.data.encode(), 'base64')
@ -293,15 +302,15 @@ class Storage:
return None
def get(self, skey: typing.Union[str, bytes]) -> typing.Optional[typing.Union[str, bytes]]:
return self.readData(skey)
return self.read_from_db(skey)
def getPickle(self, skey: typing.Union[str, bytes]) -> typing.Any:
v = self.readData(skey, True)
def get_unpickle(self, skey: typing.Union[str, bytes]) -> typing.Any:
v = self.read_from_db(skey, True)
if v:
return pickle.loads(typing.cast(bytes, v)) # nosec: This is e controled pickle loading
return None
def getPickleByAttr1(self, attr1: str, forUpdate: bool = False):
def get_unpickle_by_attr1(self, attr1: str, forUpdate: bool = False):
try:
query = DBStorage.objects.filter(owner=self._owner, attr1=attr1)
if forUpdate:
@ -312,11 +321,16 @@ class Storage:
except Exception:
return None
def remove(self, skey: typing.Union[collections.abc.Iterable[typing.Union[str, bytes]], str, bytes]) -> None:
keys: collections.abc.Iterable[typing.Union[str, bytes]] = [skey] if isinstance(skey, (str, bytes)) else skey
def remove(
self, skey: typing.Union[collections.abc.Iterable[typing.Union[str, bytes]], str, bytes]
) -> None:
keys: collections.abc.Iterable[typing.Union[str, bytes]] = typing.cast(
collections.abc.Iterable[typing.Union[str, bytes]],
[skey] if isinstance(skey, (str, bytes)) else skey,
)
try:
# Process several keys at once
DBStorage.objects.filter(key__in=[self.getKey(k) for k in keys]).delete()
DBStorage.objects.filter(key__in=[self.get_key(k) for k in keys]).delete()
except Exception: # nosec: Not interested in processing exceptions, just ignores it
pass
@ -342,7 +356,9 @@ class Storage:
) -> StorageAccess:
return StorageAccess(self._owner, group=group, atomic=atomic, compat=compat)
def locateByAttr1(self, attr1: typing.Union[collections.abc.Iterable[str], str]) -> collections.abc.Iterable[bytes]:
def search_by_attr1(
self, attr1: typing.Union[collections.abc.Iterable[str], str]
) -> collections.abc.Iterable[bytes]:
if isinstance(attr1, str):
query = DBStorage.objects.filter(owner=self._owner, attr1=attr1) # @UndefinedVariable
else:
@ -365,7 +381,7 @@ class Storage:
for v in query: # @UndefinedVariable
yield (v.key, codecs.decode(v.data.encode(), 'base64'), v.attr1)
def filterPickle(
def filter_unpickle(
self, attr1: typing.Optional[str] = None, forUpdate: bool = False
) -> collections.abc.Iterable[tuple[str, typing.Any, 'str|None']]:
for v in self.filter(attr1, forUpdate):
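
Taken together, the renamed Storage API reads like this; a short usage sketch (owner string and values are illustrative, method names as introduced above):

storage = Storage('some-owner')
storage.save_to_db('plain-key', 'plain value', attr1='tag')  # formerly saveData()
storage.put_pickle('pickled-key', {'answer': 42})            # still pickles under the hood
plain = storage.read_from_db('plain-key')                    # formerly readData()
data = storage.get_unpickle('pickled-key')                   # formerly getPickle()
rows = list(storage.search_by_attr1('tag'))                  # formerly locateByAttr1()
storage.remove(['plain-key', 'pickled-key'])                 # accepts a single key or an iterable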

View File

@ -128,7 +128,7 @@ class TOTP_MFA(mfas.MFA):
# Get data from storage related to this user
# Data contains the secret and if the user has already logged in already some time
# so we show the QR code only once
data: typing.Optional[tuple[str, bool]] = self.storage.getPickle(userId)
data: typing.Optional[tuple[str, bool]] = self.storage.get_unpickle(userId)
if data is None:
data = (pyotp.random_base32(), False)
self._saveUserData(userId, data)

View File

@ -36,7 +36,7 @@ import collections.abc
from django.db import models
from uds.core import transports, types
from uds.core import transports, types, consts
from uds.core.util import net
@ -58,13 +58,8 @@ class Transport(ManagedObjectModel, TaggingMixin):
Sample of transports are RDP, Spice, Web file uploader, etc...
"""
# Constants for net_filter
NO_FILTERING = 'n'
ALLOW = 'a'
DENY = 'd'
priority = models.IntegerField(default=0, db_index=True)
net_filtering = models.CharField(max_length=1, default=NO_FILTERING, db_index=True)
net_filtering = models.CharField(max_length=1, default=consts.auth.NO_FILTERING, db_index=True)
# We store allowed oss as a comma-separated list
allowed_oss = models.CharField(max_length=255, default='')
# Label, to group transports on meta pools
@ -130,14 +125,14 @@ class Transport(ManagedObjectModel, TaggingMixin):
# Avoid circular import
from uds.models import Network # pylint: disable=import-outside-toplevel
if self.net_filtering == Transport.NO_FILTERING:
if self.net_filtering == consts.auth.NO_FILTERING:
return True
ip, version = net.ip_to_long(ipStr)
# Allow
exists = self.networks.filter(
start__lte=Network.hexlify(ip), end__gte=Network.hexlify(ip), version=version
).exists()
if self.net_filtering == Transport.ALLOW:
if self.net_filtering == consts.auth.ALLOW:
return exists
# Deny, must not be in any network
return not exists
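
The filtering decision above, now driven by consts.auth instead of the removed per-model constants, condenses to the following self-contained restatement (constant values as defined in the consts.auth hunk earlier in this diff):

NO_FILTERING, ALLOW, DENY = 'n', 'a', 'd'  # values from the consts.auth hunk earlier in this diff

def is_ip_allowed(net_filtering: str, ip_in_some_listed_network: bool) -> bool:
    if net_filtering == NO_FILTERING:
        return True
    if net_filtering == ALLOW:              # allow-list: the IP must be inside one of the networks
        return ip_in_some_listed_network
    return not ip_in_some_listed_network    # deny-list: the IP must be outside all of them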

View File

@ -140,7 +140,7 @@ class TelegramNotifier(messaging.Notifier):
telegramMsg = f'{group} - {identificator} - {str(level)}: {message}'
logger.debug('Sending telegram message: %s', telegramMsg)
# load chatIds
chatIds = self.storage.getPickle('chatIds') or []
chatIds = self.storage.get_unpickle('chatIds') or []
t = telegram.Telegram(self.accessToken.value, self.botname.value)
for chatId in chatIds:
with ignoreExceptions():
@ -151,7 +151,7 @@ class TelegramNotifier(messaging.Notifier):
def subscribeUser(self, chatId: int) -> None:
# we do not expect to have a lot of users, so we will use a simple storage
# that holds a list of chatIds
chatIds = self.storage.getPickle('chatIds') or []
chatIds = self.storage.get_unpickle('chatIds') or []
if chatId not in chatIds:
chatIds.append(chatId)
self.storage.put_pickle('chatIds', chatIds)
@ -160,7 +160,7 @@ class TelegramNotifier(messaging.Notifier):
def unsubscriteUser(self, chatId: int) -> None:
# we do not expect to have a lot of users, so we will use a simple storage
# that holds a list of chatIds
chatIds = self.storage.getPickle('chatIds') or []
chatIds = self.storage.get_unpickle('chatIds') or []
if chatId in chatIds:
chatIds.remove(chatId)
self.storage.put_pickle('chatIds', chatIds)
@ -170,7 +170,7 @@ class TelegramNotifier(messaging.Notifier):
if not self.accessToken.value.strip():
return # no access token, no messages
# Time of last retrieve
lastCheck: typing.Optional[datetime.datetime] = self.storage.getPickle('lastCheck')
lastCheck: typing.Optional[datetime.datetime] = self.storage.get_unpickle('lastCheck')
now = sql_datetime()
# If last check is not set, we will set it to now
@ -185,7 +185,7 @@ class TelegramNotifier(messaging.Notifier):
# Update last check
self.storage.put_pickle('lastCheck', now)
lastOffset = self.storage.getPickle('lastOffset') or 0
lastOffset = self.storage.get_unpickle('lastOffset') or 0
t = telegram.Telegram(self.accessToken.value, last_offset=lastOffset)
with ignoreExceptions(): # In case getUpdates fails, ignore it
for update in t.get_updates():

View File

@ -71,10 +71,10 @@ class OVirtDeferredRemoval(jobs.Job):
if state in ('up', 'powering_up', 'suspended'):
providerInstance.stopMachine(vmId)
elif state != 'unknown': # Machine exists, remove it later
providerInstance.storage.saveData('tr' + vmId, vmId, attr1='tRm')
providerInstance.storage.save_to_db('tr' + vmId, vmId, attr1='tRm')
except Exception as e:
providerInstance.storage.saveData('tr' + vmId, vmId, attr1='tRm')
providerInstance.storage.save_to_db('tr' + vmId, vmId, attr1='tRm')
logger.info(
'Machine %s could not be removed right now, queued for later: %s',
vmId,

View File

@ -169,7 +169,7 @@ class IPMachinesService(IPServiceBase):
'{}~{}'.format(str(ip).strip(), i) for i, ip in enumerate(values['ipList']) if str(ip).strip()
] # Allow duplicates right now
# Current stored data, if it exists
d = self.storage.readData('ips')
d = self.storage.read_from_db('ips')
old_ips = pickle.loads(d) if d and isinstance(d, bytes) else [] # nosec: pickle is safe here
# dissapeared ones
dissapeared = set(IPServiceBase.getIp(i.split('~')[0]) for i in old_ips) - set(
@ -202,7 +202,7 @@ class IPMachinesService(IPServiceBase):
}
def marshal(self) -> bytes:
self.storage.saveData('ips', pickle.dumps(self._ips))
self.storage.save_to_db('ips', pickle.dumps(self._ips))
return b'\0'.join(
[
b'v7',
@ -217,7 +217,7 @@ class IPMachinesService(IPServiceBase):
def unmarshal(self, data: bytes) -> None:
values: list[bytes] = data.split(b'\0')
d = self.storage.readData('ips')
d = self.storage.read_from_db('ips')
if isinstance(d, bytes):
self._ips = pickle.loads(d) # nosec: pickle is safe here
elif isinstance(d, str): # "legacy" saved elements
@ -272,7 +272,7 @@ class IPMachinesService(IPServiceBase):
for ip in allIps:
theIP = IPServiceBase.getIp(ip)
theMAC = IPServiceBase.getMac(ip)
locked = self.storage.getPickle(theIP)
locked = self.storage.get_unpickle(theIP)
if self.canBeUsed(locked, now):
if (
self._port > 0
@ -320,7 +320,7 @@ class IPMachinesService(IPServiceBase):
logger.exception("Exception at getUnassignedMachine")
def enumerate_assignables(self):
return [(ip, ip.split('~')[0]) for ip in self._ips if self.storage.readData(ip) is None]
return [(ip, ip.split('~')[0]) for ip in self._ips if self.storage.read_from_db(ip) is None]
def assign_from_assignables(
self,
@ -333,7 +333,7 @@ class IPMachinesService(IPServiceBase):
theMAC = IPServiceBase.getMac(assignableId)
now = sql_stamp_seconds()
locked = self.storage.getPickle(theIP)
locked = self.storage.get_unpickle(theIP)
if self.canBeUsed(locked, now):
self.storage.put_pickle(theIP, now)
if theMAC:
@ -351,7 +351,7 @@ class IPMachinesService(IPServiceBase):
# Locate the IP on the storage
theIP = IPServiceBase.getIp(id)
now = sql_stamp_seconds()
locked: typing.Union[None, str, int] = self.storage.getPickle(theIP)
locked: typing.Union[None, str, int] = self.storage.get_unpickle(theIP)
if self.canBeUsed(locked, now):
self.storage.put_pickle(theIP, str(now)) # Lock it
@ -362,7 +362,7 @@ class IPMachinesService(IPServiceBase):
logger.debug('Processing logout for %s: %s', self, id)
# Locate the IP on the storage
theIP = IPServiceBase.getIp(id)
locked: typing.Union[None, str, int] = self.storage.getPickle(theIP)
locked: typing.Union[None, str, int] = self.storage.get_unpickle(theIP)
# If locked is str, has been locked by processLogin so we can unlock it
if isinstance(locked, str):
self.unassignMachine(id)

View File

@ -88,7 +88,7 @@ class IPSingleMachineService(IPServiceBase):
def getUnassignedMachine(self) -> typing.Optional[str]:
ip: typing.Optional[str] = None
try:
counter = self.storage.getPickle('counter')
counter = self.storage.get_unpickle('counter')
counter = counter + 1 if counter is not None else 1
self.storage.put_pickle('counter', counter)
ip = '{}~{}'.format(self.ip.value, counter)

View File

@ -513,7 +513,7 @@ if sys.platform == 'win32':
"""
Check if the machine has gracely stopped (timed shutdown)
"""
shutdown_start = self.storage.getPickle('shutdown')
shutdown_start = self.storage.get_unpickle('shutdown')
logger.debug('Shutdown start: %s', shutdown_start)
if shutdown_start < 0: # Was already stopped
# Machine is already stop

View File

@ -72,7 +72,7 @@ class ProxmoxDeferredRemoval(jobs.Job):
except client.ProxmoxNotFound:
return # Machine does not exists
except Exception as e:
providerInstance.storage.saveData('tr' + str(vmId), str(vmId), attr1='tRm')
providerInstance.storage.save_to_db('tr' + str(vmId), str(vmId), attr1='tRm')
logger.info(
'Machine %s could not be removed right now, queued for later: %s',
vmId,

View File

@ -117,13 +117,13 @@ class SampleUserServiceOne(services.UserService):
a new unique name, so we keep the first generated name cached and don't
generate more names. (Generator are simple utility classes)
"""
name: str = typing.cast(str, self.storage.readData('name'))
name: str = typing.cast(str, self.storage.read_from_db('name'))
if name is None:
name = self.name_generator().get(
self.service().get_base_name() + '-' + self.service().getColour(), 3
)
# Store value for persistence
self.storage.saveData('name', name)
self.storage.save_to_db('name', name)
return name
@ -139,7 +139,7 @@ class SampleUserServiceOne(services.UserService):
:note: This IP is the IP of the "consumed service", so the transport can
access it.
"""
self.storage.saveData('ip', ip)
self.storage.save_to_db('ip', ip)
def get_unique_id(self) -> str:
"""
@ -151,10 +151,10 @@ class SampleUserServiceOne(services.UserService):
The get method of a mac generator takes one param, that is the mac range
to use to get an unused mac.
"""
mac = typing.cast(str, self.storage.readData('mac'))
mac = typing.cast(str, self.storage.read_from_db('mac'))
if mac is None:
mac = self.mac_generator().get('00:00:00:00:00:00-00:FF:FF:FF:FF:FF')
self.storage.saveData('mac', mac)
self.storage.save_to_db('mac', mac)
return mac
def get_ip(self) -> str:
@ -175,7 +175,7 @@ class SampleUserServiceOne(services.UserService):
show the IP to the administrator, this method will get called
"""
ip = typing.cast(str, self.storage.readData('ip'))
ip = typing.cast(str, self.storage.read_from_db('ip'))
if ip is None:
ip = '192.168.0.34' # Sample IP for testing purposses only
return ip
@ -241,11 +241,11 @@ class SampleUserServiceOne(services.UserService):
"""
import random
self.storage.saveData('count', '0')
self.storage.save_to_db('count', '0')
# random fail
if random.randint(0, 9) == 9: # nosec: just testing values
self.storage.saveData('error', 'Random error at deployForUser :-)')
self.storage.save_to_db('error', 'Random error at deployForUser :-)')
return State.ERROR
return State.RUNNING
@ -274,7 +274,7 @@ class SampleUserServiceOne(services.UserService):
import random
countStr: typing.Optional[str] = typing.cast(
str, self.storage.readData('count')
str, self.storage.read_from_db('count')
)
count: int = 0
if countStr:
@ -288,10 +288,10 @@ class SampleUserServiceOne(services.UserService):
# random fail
if random.randint(0, 9) == 9: # nosec: just testing values
self.storage.saveData('error', 'Random error at check_state :-)')
self.storage.save_to_db('error', 'Random error at check_state :-)')
return State.ERROR
self.storage.saveData('count', str(count))
self.storage.save_to_db('count', str(count))
return State.RUNNING
def finish(self) -> None:
@ -322,7 +322,7 @@ class SampleUserServiceOne(services.UserService):
The user provided is just an string, that is provided by actor.
"""
# We store the value at storage, but never get used, just an example
self.storage.saveData('user', username)
self.storage.save_to_db('user', username)
def user_logged_out(self, username) -> None:
"""
@ -349,7 +349,7 @@ class SampleUserServiceOne(services.UserService):
for it, and it will be asked everytime it's needed to be shown to the
user (when the administation asks for it).
"""
return typing.cast(str, self.storage.readData('error')) or 'No error'
return typing.cast(str, self.storage.read_from_db('error')) or 'No error'
def destroy(self) -> str:
"""

View File

@ -388,7 +388,7 @@ class SampleUserServiceTwo(services.UserService):
The user provided is just an string, that is provided by actors.
"""
# We store the value at storage, but never get used, just an example
self.storage.saveData('user', username)
self.storage.save_to_db('user', username)
def user_logged_out(self, username) -> None:
"""

View File

@ -99,7 +99,7 @@ class TestUserService(services.UserService):
def get_ip(self) -> str:
logger.info('Getting ip of deployment %s', self.data)
ip = typing.cast(str, self.storage.readData('ip'))
ip = typing.cast(str, self.storage.read_from_db('ip'))
if ip is None:
ip = '8.6.4.2' # Sample IP for testing purposses only
return ip

View File

@ -183,7 +183,7 @@ urlpatterns = [
# Services list, ...
path(
r'uds/webapi/services',
uds.web.views.main.servicesData,
uds.web.views.main.services_data_json,
name='webapi.services',
),
# Transport own link processor

View File

@ -35,7 +35,7 @@ from django.http import HttpResponseRedirect
from django.utils.translation import gettext as _
from uds.core.auths.auth import authenticate, authenticate_log_login
from uds.core.auths.auth import authenticate, log_login
from uds.models import Authenticator, User
from uds.core.util.config import GlobalConfig
from uds.core.util.cache import Cache
@ -114,13 +114,13 @@ def check_login( # pylint: disable=too-many-branches, too-many-statements
and (tries >= maxTries)
or triesByIp >= maxTries
):
authenticate_log_login(request, authenticator, userName, 'Temporarily blocked')
log_login(request, authenticator, userName, 'Temporarily blocked')
return LoginResult(
errstr=_('Too many authentication errrors. User temporarily blocked')
)
# check if authenticator is visible for this requests
if authInstance.is_ip_allowed(request=request) is False:
authenticate_log_login(
log_login(
request,
authenticator,
userName,
@ -136,7 +136,7 @@ def check_login( # pylint: disable=too-many-branches, too-many-statements
logger.debug("Invalid user %s (access denied)", userName)
cache.put(cacheKey, tries + 1, GlobalConfig.LOGIN_BLOCK.getInt())
cache.put(request.ip, triesByIp + 1, GlobalConfig.LOGIN_BLOCK.getInt())
authenticate_log_login(
log_login(
request,
authenticator,
userName,
@ -154,7 +154,7 @@ def check_login( # pylint: disable=too-many-branches, too-many-statements
if form.cleaned_data['logouturl'] != '':
logger.debug('The logoout url will be %s', form.cleaned_data['logouturl'])
request.session['logouturl'] = form.cleaned_data['logouturl']
authenticate_log_login(request, authenticator, authResult.user.name)
log_login(request, authenticator, authResult.user.name)
return LoginResult(user=authResult.user, password=form.cleaned_data['password'])
logger.info('Invalid form received')

View File

@ -45,7 +45,7 @@ from uds.core.auths.auth import (
web_login,
web_logout,
authenticate_via_callback,
authenticate_log_login,
log_login,
getUDSCookie,
)
from uds.core.managers.user_service import UserServiceManager
@ -111,14 +111,14 @@ def auth_callback_stage2(request: 'ExtendedHttpRequestWithUser', ticketId: str)
raise exceptions.auth.Redirect(result.url)
if result.user is None:
authenticate_log_login(request, authenticator, f'{params}', 'Invalid at auth callback')
log_login(request, authenticator, f'{params}', 'Invalid at auth callback')
raise exceptions.auth.InvalidUserException()
response = HttpResponseRedirect(reverse('page.index'))
web_login(request, response, result.user, '') # Password is unavailable in this case
authenticate_log_login(request, authenticator, result.user.name, 'Federated login')
log_login(request, authenticator, result.user.name, 'Federated login')
# If MFA is provided, we need to redirect to MFA page
request.authorized = True
@ -239,7 +239,7 @@ def ticket_auth(
web_login(request, None, usr, password)
# Log the login
authenticate_log_login(request, auth, username, 'Ticket authentication')
log_login(request, auth, username, 'Ticket authentication')
request.user = (
usr # Temporarily store this user as "authenticated" user, next requests will be done using session
@ -256,7 +256,9 @@ def ticket_auth(
# Check if servicePool is part of the ticket
if poolUuid:
# Request service, with transport = None so it is automatic
res = UserServiceManager().get_user_service_info(request.user, request.os, request.ip, poolUuid, None, False)
res = UserServiceManager().get_user_service_info(
request.user, request.os, request.ip, poolUuid, None, False
)
_, userService, _, transport, _ = res
transportInstance = transport.get_instance()

View File

@ -28,46 +28,39 @@
"""
Author: Adolfo Gómez, dkmaster at dkmon dot com
"""
import collections.abc
import datetime
import json
import logging
import random
import re
import time
import logging
import typing
import collections.abc
import random
import json
from django.http import HttpRequest, HttpResponse, HttpResponseRedirect, JsonResponse
from django.middleware import csrf
from django.shortcuts import render
from django.views.decorators.csrf import csrf_exempt
from django.http import HttpRequest, HttpResponse, JsonResponse, HttpResponseRedirect
from django.views.decorators.cache import never_cache
from django.urls import reverse
from django.utils.translation import gettext as _
from py import log
from uds.core.types.request import ExtendedHttpRequest
from django.views.decorators.cache import never_cache
from django.views.decorators.csrf import csrf_exempt
from uds.core.types.request import ExtendedHttpRequestWithUser
from uds import auths, models
from uds.core import exceptions, mfas, types, consts
from uds.core.auths import auth
from uds.core.util import state
from uds.core.util.config import GlobalConfig
from uds.core.managers.crypto import CryptoManager
from uds.core.managers.user_service import UserServiceManager
from uds.web.util import errors
from uds.core.types.request import ExtendedHttpRequest, ExtendedHttpRequestWithUser
from uds.core.util import config, state, storage
from uds.core.util.model import sql_stamp_seconds
from uds.web.forms.LoginForm import LoginForm
from uds.web.forms.MFAForm import MFAForm
from uds.web.util import configjs, errors
from uds.web.util.authentication import check_login
from uds.web.util.services import getServicesData
from uds.web.util import configjs
from uds.core import mfas, types, exceptions
from uds import auths, models
from uds.core.util.model import sql_stamp_seconds
logger = logging.getLogger(__name__)
CSRF_FIELD = 'csrfmiddlewaretoken'
MFA_COOKIE_NAME = 'mfa_status'
if typing.TYPE_CHECKING:
pass
@ -82,7 +75,7 @@ def index(request: HttpRequest) -> HttpResponse:
response = render(
request=request,
template_name='uds/modern/index.html',
context={'csrf_field': CSRF_FIELD, 'csrf_token': csrf_token},
context={'csrf_field': consts.auth.CSRF_FIELD, 'csrf_token': csrf_token},
)
# Ensure UDS cookie is present
@ -153,7 +146,7 @@ def login(request: ExtendedHttpRequest, tag: typing.Optional[str] = None) -> Htt
@never_cache
@auth.web_login_required(admin=False)
def logout(request: ExtendedHttpRequestWithUser) -> HttpResponse:
auth.auth_log_logout(request)
auth.log_logout(request)
request.session['restricted'] = False # Remove restricted
request.authorized = False
logoutResponse = request.user.logout(request)
@ -169,11 +162,11 @@ def js(request: ExtendedHttpRequest) -> HttpResponse:
@never_cache
@auth.deny_non_authenticated # web_login_required not used here because this is not a web page, but js
def servicesData(request: ExtendedHttpRequestWithUser) -> HttpResponse:
def services_data_json(request: ExtendedHttpRequestWithUser) -> HttpResponse:
return JsonResponse(getServicesData(request))
# The MFA page does not needs CRF token, so we disable it
# The MFA page does not need a CSRF token, so we disable it
@csrf_exempt
def mfa(
request: ExtendedHttpRequest,
@ -182,26 +175,37 @@ def mfa(
logger.warning('MFA: No user or user is already authorized')
return HttpResponseRedirect(reverse('page.index')) # No user, no MFA
mfaProvider = typing.cast('None|models.MFA', request.user.manager.mfa)
if not mfaProvider:
store: 'storage.Storage' = storage.Storage('mfs')
mfa_provider = typing.cast('None|models.MFA', request.user.manager.mfa)
if not mfa_provider:
logger.warning('MFA: No MFA provider for user')
return HttpResponseRedirect(reverse('page.index'))
mfaUserId = mfas.MFA.get_user_id(request.user)
mfa_user_id = mfas.MFA.get_user_id(request.user)
# Try to get the cookie and check it
mfaCookie = request.COOKIES.get(MFA_COOKIE_NAME, None)
if mfaCookie == mfaUserId: # Cookie is valid, skip MFA setting authorization
logger.debug('MFA: Cookie is valid, skipping MFA')
request.authorized = True
return HttpResponseRedirect(reverse('page.index'))
mfa_cookie = request.COOKIES.get(consts.auth.MFA_COOKIE_NAME, None)
if mfa_cookie and mfa_provider.remember_device > 0:
stored_user_id: typing.Optional[str]
created: typing.Optional[datetime.datetime]
stored_user_id, created = store.get_unpickle(mfa_cookie) or (None, None)
if (
stored_user_id
and created
and created + datetime.timedelta(hours=mfa_provider.remember_device) > datetime.datetime.now()
):
# Cookie is valid, skip MFA setting authorization
logger.debug('MFA: Cookie is valid, skipping MFA')
request.authorized = True
return HttpResponseRedirect(reverse('page.index'))
# Obtain MFA data
authInstance = request.user.manager.get_instance()
mfaInstance = typing.cast('mfas.MFA', mfaProvider.get_instance())
auth_instance = request.user.manager.get_instance()
mfa_instance = typing.cast('mfas.MFA', mfa_provider.get_instance())
# Get validity duration
validity = mfaProvider.validity * 60
validity = mfa_provider.validity * 60
now = sql_stamp_seconds()
start_time = request.session.get('mfa_start_time', now)
@ -211,26 +215,26 @@ def mfa(
request.session.flush() # Clear session, and redirect to login
return HttpResponseRedirect(reverse('page.login'))
mfaIdentifier = authInstance.mfa_identifier(request.user.name)
label = mfaInstance.label()
mfa_identifier = auth_instance.mfa_identifier(request.user.name)
label = mfa_instance.label()
if not mfaIdentifier:
emtpyIdentifiedAllowed = mfaInstance.allow_login_without_identifier(request)
if not mfa_identifier:
allow_login_without_identifier = mfa_instance.allow_login_without_identifier(request)
# can be True, False or None
if emtpyIdentifiedAllowed is True:
if allow_login_without_identifier is True:
# Allow login
request.authorized = True
return HttpResponseRedirect(reverse('page.index'))
if emtpyIdentifiedAllowed is False:
if allow_login_without_identifier is False:
# Not allowed to login, redirect to login error page
logger.warning(
'MFA identifier not found for user %s on authenticator %s. It is required by MFA %s',
request.user.name,
request.user.manager.name,
mfaProvider.name,
mfa_provider.name,
)
return errors.errorView(request, errors.ACCESS_DENIED)
# None, the authenticator will decide what to do if mfaIdentifier is empty
# None, the authenticator will decide what to do if mfa_identifier is empty
tries = request.session.get('mfa_tries', 0)
if request.method == 'POST': # User has provided MFA code
@ -238,11 +242,11 @@ def mfa(
if form.is_valid():
code = form.cleaned_data['code']
try:
mfaInstance.validate(
mfa_instance.validate(
request,
mfaUserId,
mfa_user_id,
request.user.name,
mfaIdentifier,
mfa_identifier,
code,
validity=validity,
) # Will raise MFAError if code is not valid
@ -256,11 +260,17 @@ def mfa(
response = HttpResponseRedirect(reverse('page.index'))
# If mfaProvider requests to keep MFA code on client, create a mfacookie for this user
if mfaProvider.remember_device > 0 and form.cleaned_data['remember'] is True:
if mfa_provider.remember_device > 0 and form.cleaned_data['remember'] is True:
# Store also cookie locally, to check if remember_device is changed
mfa_cookie = CryptoManager().random_string(96)
store.put_pickle(
mfa_cookie,
(mfa_user_id, now),
)
response.set_cookie(
MFA_COOKIE_NAME,
mfaUserId,
max_age=mfaProvider.remember_device * 60 * 60,
consts.auth.MFA_COOKIE_NAME,
mfa_cookie,
max_age=mfa_provider.remember_device * 60 * 60,
)
return response
@ -268,7 +278,7 @@ def mfa(
logger.error('MFA error: %s', e)
tries += 1
request.session['mfa_tries'] = tries
if tries >= GlobalConfig.MAX_LOGIN_TRIES.getInt():
if tries >= config.GlobalConfig.MAX_LOGIN_TRIES.getInt():
# Clean session
request.session.flush()
# Too many tries, redirect to login error page
@ -280,11 +290,11 @@ def mfa(
# Make MFA send a code
request.session['mfa_tries'] = 0 # Reset tries
try:
result = mfaInstance.process(
result = mfa_instance.process(
request,
mfaUserId,
mfa_user_id,
request.user.name,
mfaIdentifier,
mfa_identifier,
validity=validity,
)
if result == mfas.MFA.RESULT.ALLOWED:
@ -302,15 +312,15 @@ def mfa(
# Compose a nice "XX years, XX months, XX days, XX hours, XX minutes" string from mfaProvider.remember_device
remember_device = ''
# Remember_device is in hours
if mfaProvider.remember_device > 0:
if mfa_provider.remember_device > 0:
# if more than a day, we show days only
if mfaProvider.remember_device >= 24:
remember_device = _('{} days').format(mfaProvider.remember_device // 24)
if mfa_provider.remember_device >= 24:
remember_device = _('{} days').format(mfa_provider.remember_device // 24)
else:
remember_device = _('{} hours').format(mfaProvider.remember_device)
remember_device = _('{} hours').format(mfa_provider.remember_device)
# Html from MFA provider
mfaHtml = mfaInstance.html(request, mfaUserId, request.user.name)
mfaHtml = mfa_instance.html(request, mfa_user_id, request.user.name)
# Redirect to index, but with MFA data
request.session['mfa'] = {

View File

@ -43,24 +43,24 @@ class StorageTest(UDSTestCase):
storage = Storage(UNICODE_CHARS)
storage.put(UNICODE_CHARS, b'chars')
storage.saveData('saveData', UNICODE_CHARS, UNICODE_CHARS)
storage.saveData('saveData2', UNICODE_CHARS_2, UNICODE_CHARS)
storage.saveData('saveData3', UNICODE_CHARS, 'attribute')
storage.saveData('saveData4', UNICODE_CHARS_2, 'attribute')
storage.save_to_db('saveData', UNICODE_CHARS, UNICODE_CHARS)
storage.save_to_db('saveData2', UNICODE_CHARS_2, UNICODE_CHARS)
storage.save_to_db('saveData3', UNICODE_CHARS, 'attribute')
storage.save_to_db('saveData4', UNICODE_CHARS_2, 'attribute')
storage.put(b'key', UNICODE_CHARS)
storage.put(UNICODE_CHARS_2, UNICODE_CHARS)
storage.put_pickle('pickle', VALUE_1)
self.assertEqual(storage.get(UNICODE_CHARS), u'chars') # Always returns unicod
self.assertEqual(storage.readData('saveData'), UNICODE_CHARS)
self.assertEqual(storage.readData('saveData2'), UNICODE_CHARS_2)
self.assertEqual(storage.read_from_db('saveData'), UNICODE_CHARS)
self.assertEqual(storage.read_from_db('saveData2'), UNICODE_CHARS_2)
self.assertEqual(storage.get(b'key'), UNICODE_CHARS)
self.assertEqual(storage.get(UNICODE_CHARS_2), UNICODE_CHARS)
self.assertEqual(storage.getPickle('pickle'), VALUE_1)
self.assertEqual(storage.get_unpickle('pickle'), VALUE_1)
self.assertEqual(len(list(storage.locateByAttr1(UNICODE_CHARS))), 2)
self.assertEqual(len(list(storage.locateByAttr1('attribute'))), 2)
self.assertEqual(len(list(storage.search_by_attr1(UNICODE_CHARS))), 2)
self.assertEqual(len(list(storage.search_by_attr1('attribute'))), 2)
storage.remove(UNICODE_CHARS)
storage.remove(b'key')
@ -68,4 +68,4 @@ class StorageTest(UDSTestCase):
self.assertIsNone(storage.get(UNICODE_CHARS))
self.assertIsNone(storage.get(b'key'))
self.assertIsNone(storage.getPickle('pickle'))
self.assertIsNone(storage.get_unpickle('pickle'))