mirror of
https://github.com/dkmstr/openuds.git
synced 2024-12-25 23:21:41 +03:00
Improved mfa code storage on browser and related security
This commit is contained in:
parent
1715e1a7a1
commit
bf3d36c901
@ -41,7 +41,7 @@ import dns.reversename
|
||||
from django.utils.translation import gettext_noop as _
|
||||
|
||||
from uds.core import auths, types, exceptions, consts
|
||||
from uds.core.auths.auth import authenticate_log_login
|
||||
from uds.core.auths.auth import log_login
|
||||
from uds.core.managers.crypto import CryptoManager
|
||||
from uds.core.ui import gui
|
||||
from uds.core.util.state import State
|
||||
@ -154,7 +154,7 @@ class InternalDBAuth(auths.Authenticator):
|
||||
try:
|
||||
user: 'models.User' = dbAuth.users.get(name=username, state=State.ACTIVE)
|
||||
except Exception:
|
||||
authenticate_log_login(request, self.db_obj(), username, 'Invalid user')
|
||||
log_login(request, self.db_obj(), username, 'Invalid user')
|
||||
return types.auth.FAILED_AUTH
|
||||
|
||||
if user.parent: # Direct auth not allowed for "derived" users
|
||||
@ -165,7 +165,7 @@ class InternalDBAuth(auths.Authenticator):
|
||||
groupsManager.validate([g.name for g in user.groups.all()])
|
||||
return types.auth.SUCCESS_AUTH
|
||||
|
||||
authenticate_log_login(request, self.db_obj(), username, 'Invalid password')
|
||||
log_login(request, self.db_obj(), username, 'Invalid password')
|
||||
return types.auth.FAILED_AUTH
|
||||
|
||||
def get_groups(self, username: str, groupsManager: 'auths.GroupsManager'):
|
||||
|
@ -226,7 +226,7 @@ class OAuth2Authenticator(auths.Authenticator):
|
||||
tab=_('Attributes'),
|
||||
)
|
||||
|
||||
def _getPublicKeys(self) -> list[typing.Any]: # In fact, any of the PublicKey types
|
||||
def _get_public_keys(self) -> list[typing.Any]: # In fact, any of the PublicKey types
|
||||
# Get certificates in self.publicKey.value, encoded as PEM
|
||||
# Return a list of certificates in DER format
|
||||
if self.publicKey.value.strip() == '':
|
||||
@ -234,7 +234,7 @@ class OAuth2Authenticator(auths.Authenticator):
|
||||
|
||||
return [cert.public_key() for cert in fields.get_vertificates_from_field(self.publicKey)]
|
||||
|
||||
def _codeVerifierAndChallenge(self) -> tuple[str, str]:
|
||||
def _code_verifier_and_challenge(self) -> tuple[str, str]:
|
||||
"""Generate a code verifier and a code challenge for PKCE
|
||||
|
||||
Returns:
|
||||
@ -249,7 +249,7 @@ class OAuth2Authenticator(auths.Authenticator):
|
||||
|
||||
return codeVerifier, codeChallenge
|
||||
|
||||
def _getResponseTypeString(self) -> str:
|
||||
def _get_response_type_string(self) -> str:
|
||||
match self.responseType.value:
|
||||
case 'code':
|
||||
return 'code'
|
||||
@ -264,14 +264,14 @@ class OAuth2Authenticator(auths.Authenticator):
|
||||
case _:
|
||||
raise Exception('Invalid response type')
|
||||
|
||||
def _getLoginURL(self, request: 'HttpRequest') -> str:
|
||||
def _get_login_url(self, request: 'HttpRequest') -> str:
|
||||
"""
|
||||
:type request: django.http.request.HttpRequest
|
||||
"""
|
||||
state: str = secrets.token_urlsafe(STATE_LENGTH)
|
||||
|
||||
param_dict = {
|
||||
'response_type': self._getResponseTypeString(),
|
||||
'response_type': self._get_response_type_string(),
|
||||
'client_id': self.clientId.value,
|
||||
'redirect_uri': self.redirectionEndpoint.value,
|
||||
'scope': self.scope.value.replace(',', ' '),
|
||||
@ -296,7 +296,7 @@ class OAuth2Authenticator(auths.Authenticator):
|
||||
|
||||
case 'pkce':
|
||||
# PKCE flow
|
||||
codeVerifier, codeChallenge = self._codeVerifierAndChallenge()
|
||||
codeVerifier, codeChallenge = self._code_verifier_and_challenge()
|
||||
param_dict['code_challenge'] = codeChallenge
|
||||
param_dict['code_challenge_method'] = 'S256'
|
||||
self.cache.put(state, codeVerifier, 3600)
|
||||
@ -315,7 +315,7 @@ class OAuth2Authenticator(auths.Authenticator):
|
||||
|
||||
return self.authorizationEndpoint.value + '?' + params
|
||||
|
||||
def _requestToken(self, code: str, code_verifier: typing.Optional[str] = None) -> TokenInfo:
|
||||
def _request_token(self, code: str, code_verifier: typing.Optional[str] = None) -> TokenInfo:
|
||||
"""Request a token from the token endpoint using the code received from the authorization endpoint
|
||||
|
||||
Args:
|
||||
@ -342,7 +342,7 @@ class OAuth2Authenticator(auths.Authenticator):
|
||||
|
||||
return TokenInfo.from_dict(req.json())
|
||||
|
||||
def _requestInfo(self, token: 'TokenInfo') -> dict[str, typing.Any]:
|
||||
def _request_info(self, token: 'TokenInfo') -> dict[str, typing.Any]:
|
||||
"""Request user info from the info endpoint using the token received from the token endpoint
|
||||
|
||||
If the token endpoint returns the user info, this method will not be used
|
||||
@ -374,7 +374,7 @@ class OAuth2Authenticator(auths.Authenticator):
|
||||
userInfo = req.json()
|
||||
return userInfo
|
||||
|
||||
def _processToken(
|
||||
def _process_token(
|
||||
self, userInfo: collections.abc.Mapping[str, typing.Any], gm: 'auths.GroupsManager'
|
||||
) -> types.auth.AuthenticationResult:
|
||||
# After this point, we don't mind about the token, we only need to authenticate user
|
||||
@ -401,7 +401,7 @@ class OAuth2Authenticator(auths.Authenticator):
|
||||
# and if we are here, the user is authenticated, so we can return SUCCESS_AUTH
|
||||
return types.auth.AuthenticationResult(types.auth.AuthenticationState.SUCCESS, username=username)
|
||||
|
||||
def _processTokenOpenId(
|
||||
def _process_token_open_id(
|
||||
self, token_id: str, nonce: str, gm: 'auths.GroupsManager'
|
||||
) -> types.auth.AuthenticationResult:
|
||||
# Get token headers, to extract algorithm
|
||||
@ -410,7 +410,7 @@ class OAuth2Authenticator(auths.Authenticator):
|
||||
|
||||
# We may have multiple public keys, try them all
|
||||
# (We should only have one, but just in case)
|
||||
for key in self._getPublicKeys():
|
||||
for key in self._get_public_keys():
|
||||
logger.debug('Key = %s', key)
|
||||
try:
|
||||
payload = jwt.decode(token, key=key, audience=self.clientId.value, algorithms=[info.get('alg', 'RSA256')]) # type: ignore
|
||||
@ -423,7 +423,7 @@ class OAuth2Authenticator(auths.Authenticator):
|
||||
# All is fine, get user & look for groups
|
||||
|
||||
# Process attributes from payload
|
||||
return self._processToken(payload, gm)
|
||||
return self._process_token(payload, gm)
|
||||
except (jwt.InvalidTokenError, IndexError):
|
||||
# logger.debug('Data was invalid: %s', e)
|
||||
pass
|
||||
@ -477,13 +477,13 @@ class OAuth2Authenticator(auths.Authenticator):
|
||||
) -> types.auth.AuthenticationResult:
|
||||
match self.responseType.value:
|
||||
case 'code' | 'pkce':
|
||||
return self.authCallbackCode(parameters, gm, request)
|
||||
return self.auth_callback_code(parameters, gm, request)
|
||||
case 'token':
|
||||
return self.authCallbackToken(parameters, gm, request)
|
||||
return self.auth_callback_token(parameters, gm, request)
|
||||
case 'openid+code':
|
||||
return self.authCallbackOpenIdCode(parameters, gm, request)
|
||||
return self.auth_callback_openid_code(parameters, gm, request)
|
||||
case 'openid+token_id':
|
||||
return self.authCallbackOpenIdIdToken(parameters, gm, request)
|
||||
return self.authcallback_openid_id_token(parameters, gm, request)
|
||||
case _:
|
||||
raise Exception('Invalid response type')
|
||||
return auths.SUCCESS_AUTH
|
||||
@ -499,21 +499,21 @@ class OAuth2Authenticator(auths.Authenticator):
|
||||
"""
|
||||
We will here compose the azure request and send it via http-redirect
|
||||
"""
|
||||
return f'window.location="{self._getLoginURL(request)}";'
|
||||
return f'window.location="{self._get_login_url(request)}";'
|
||||
|
||||
def get_groups(self, username: str, groupsManager: 'auths.GroupsManager'):
|
||||
data = self.storage.getPickle(username)
|
||||
data = self.storage.get_unpickle(username)
|
||||
if not data:
|
||||
return
|
||||
groupsManager.validate(data[1])
|
||||
|
||||
def get_real_name(self, username: str) -> str:
|
||||
data = self.storage.getPickle(username)
|
||||
data = self.storage.get_unpickle(username)
|
||||
if not data:
|
||||
return username
|
||||
return data[0]
|
||||
|
||||
def authCallbackCode(
|
||||
def auth_callback_code(
|
||||
self,
|
||||
parameters: 'types.auth.AuthCallbackParams',
|
||||
gm: 'auths.GroupsManager',
|
||||
@ -539,10 +539,10 @@ class OAuth2Authenticator(auths.Authenticator):
|
||||
if code_verifier == 'none':
|
||||
code_verifier = None
|
||||
|
||||
token = self._requestToken(code, code_verifier)
|
||||
return self._processToken(self._requestInfo(token), gm)
|
||||
token = self._request_token(code, code_verifier)
|
||||
return self._process_token(self._request_info(token), gm)
|
||||
|
||||
def authCallbackToken(
|
||||
def auth_callback_token(
|
||||
self,
|
||||
parameters: 'types.auth.AuthCallbackParams',
|
||||
gm: 'auths.GroupsManager',
|
||||
@ -568,9 +568,9 @@ class OAuth2Authenticator(auths.Authenticator):
|
||||
info={},
|
||||
id_token=None,
|
||||
)
|
||||
return self._processToken(self._requestInfo(token), gm)
|
||||
return self._process_token(self._request_info(token), gm)
|
||||
|
||||
def authCallbackOpenIdCode(
|
||||
def auth_callback_openid_code(
|
||||
self,
|
||||
parameters: 'types.auth.AuthCallbackParams',
|
||||
gm: 'auths.GroupsManager',
|
||||
@ -592,15 +592,15 @@ class OAuth2Authenticator(auths.Authenticator):
|
||||
return types.auth.FAILED_AUTH
|
||||
|
||||
# Get the token, token_type, expires
|
||||
token = self._requestToken(code)
|
||||
token = self._request_token(code)
|
||||
|
||||
if not token.id_token:
|
||||
logger.error('No id_token received on OAuth2 callback')
|
||||
return types.auth.FAILED_AUTH
|
||||
|
||||
return self._processTokenOpenId(token.id_token, nonce, gm)
|
||||
return self._process_token_open_id(token.id_token, nonce, gm)
|
||||
|
||||
def authCallbackOpenIdIdToken(
|
||||
def authcallback_openid_id_token(
|
||||
self,
|
||||
parameters: 'types.auth.AuthCallbackParams',
|
||||
gm: 'auths.GroupsManager',
|
||||
@ -621,4 +621,4 @@ class OAuth2Authenticator(auths.Authenticator):
|
||||
logger.error('Invalid id_token received on OAuth2 callback')
|
||||
return types.auth.FAILED_AUTH
|
||||
|
||||
return self._processTokenOpenId(id_token, nonce, gm)
|
||||
return self._process_token_open_id(id_token, nonce, gm)
|
||||
|
@ -37,7 +37,7 @@ import collections.abc
|
||||
from django.utils.translation import gettext_noop as _
|
||||
|
||||
from uds.core import auths, types, consts
|
||||
from uds.core.auths.auth import authenticate_log_login
|
||||
from uds.core.auths.auth import log_login
|
||||
from uds.core.managers.crypto import CryptoManager
|
||||
from uds.core.ui import gui
|
||||
|
||||
@ -125,7 +125,7 @@ class RadiusAuth(auths.Authenticator):
|
||||
def initialize(self, values: typing.Optional[dict[str, typing.Any]]) -> None:
|
||||
pass
|
||||
|
||||
def radiusClient(self) -> client.RadiusClient:
|
||||
def radius_client(self) -> client.RadiusClient:
|
||||
"""Return a new radius client ."""
|
||||
return client.RadiusClient(
|
||||
self.server.value,
|
||||
@ -135,11 +135,11 @@ class RadiusAuth(auths.Authenticator):
|
||||
appClassPrefix=self.appClassPrefix.value,
|
||||
)
|
||||
|
||||
def mfaStorageKey(self, username: str) -> str:
|
||||
def mfa_storage_key(self, username: str) -> str:
|
||||
return 'mfa_' + str(self.db_obj().uuid) + username
|
||||
|
||||
def mfa_identifier(self, username: str) -> str:
|
||||
return self.storage.getPickle(self.mfaStorageKey(username)) or ''
|
||||
return self.storage.get_unpickle(self.mfa_storage_key(username)) or ''
|
||||
|
||||
def authenticate(
|
||||
self,
|
||||
@ -149,7 +149,7 @@ class RadiusAuth(auths.Authenticator):
|
||||
request: 'ExtendedHttpRequest',
|
||||
) -> types.auth.AuthenticationResult:
|
||||
try:
|
||||
connection = self.radiusClient()
|
||||
connection = self.radius_client()
|
||||
groups, mfaCode, state = connection.authenticate(username=username, password=credentials, mfaField=self.mfaAttr.value.strip())
|
||||
# If state, store in session
|
||||
if state:
|
||||
@ -157,12 +157,12 @@ class RadiusAuth(auths.Authenticator):
|
||||
# store the user mfa attribute if it is set
|
||||
if mfaCode:
|
||||
self.storage.put_pickle(
|
||||
self.mfaStorageKey(username),
|
||||
self.mfa_storage_key(username),
|
||||
mfaCode,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
authenticate_log_login(
|
||||
log_login(
|
||||
request,
|
||||
self.db_obj(),
|
||||
username,
|
||||
@ -200,15 +200,15 @@ class RadiusAuth(auths.Authenticator):
|
||||
"""Test the connection to the server ."""
|
||||
try:
|
||||
auth = RadiusAuth(None, env, data) # type: ignore
|
||||
return auth.testConnection()
|
||||
return auth.test_connection()
|
||||
except Exception as e:
|
||||
logger.error("Exception found testing Radius auth %s: %s", e.__class__, e)
|
||||
return [False, _('Error testing connection')]
|
||||
|
||||
def testConnection(self):
|
||||
def test_connection(self):
|
||||
"""Test connection to Radius Server"""
|
||||
try:
|
||||
connection = self.radiusClient()
|
||||
connection = self.radius_client()
|
||||
# Reply is not important...
|
||||
connection.authenticate(
|
||||
CryptoManager().random_string(10), CryptoManager().random_string(10)
|
||||
|
@ -40,7 +40,7 @@ import ldap
|
||||
from django.utils.translation import gettext_noop as _
|
||||
|
||||
from uds.core import auths, exceptions, types, consts
|
||||
from uds.core.auths.auth import authenticate_log_login
|
||||
from uds.core.auths.auth import log_login
|
||||
from uds.core.ui import gui
|
||||
from uds.core.util import ldaputil, auth as auth_utils
|
||||
|
||||
@ -275,7 +275,7 @@ class RegexLdap(auths.Authenticator):
|
||||
return 'mfa_' + self.db_obj().uuid + username
|
||||
|
||||
def mfa_identifier(self, username: str) -> str:
|
||||
return self.storage.getPickle(self.mfaStorageKey(username)) or ''
|
||||
return self.storage.get_unpickle(self.mfaStorageKey(username)) or ''
|
||||
|
||||
def dict_of_values(self) -> gui.ValuesDictType:
|
||||
return {
|
||||
@ -494,14 +494,14 @@ class RegexLdap(auths.Authenticator):
|
||||
usr = self.__getUser(username)
|
||||
|
||||
if usr is None:
|
||||
authenticate_log_login(request, self.db_obj(), username, 'Invalid user')
|
||||
log_login(request, self.db_obj(), username, 'Invalid user')
|
||||
return types.auth.FAILED_AUTH
|
||||
|
||||
try:
|
||||
# Let's see first if it credentials are fine
|
||||
self.__connectAs(usr['dn'], credentials) # Will raise an exception if it can't connect
|
||||
except Exception:
|
||||
authenticate_log_login(request, self.db_obj(), username, 'Invalid password')
|
||||
log_login(request, self.db_obj(), username, 'Invalid password')
|
||||
return types.auth.FAILED_AUTH
|
||||
|
||||
# store the user mfa attribute if it is set
|
||||
|
@ -567,7 +567,7 @@ class SAMLAuthenticator(auths.Authenticator):
|
||||
self.storage.remove(self.mfaStorageKey(username))
|
||||
|
||||
def mfa_identifier(self, username: str) -> str:
|
||||
return self.storage.getPickle(self.mfaStorageKey(username)) or ''
|
||||
return self.storage.get_unpickle(self.mfaStorageKey(username)) or ''
|
||||
|
||||
def logoutFromCallback(
|
||||
self,
|
||||
@ -725,13 +725,13 @@ class SAMLAuthenticator(auths.Authenticator):
|
||||
)
|
||||
|
||||
def get_groups(self, username: str, groupsManager: 'auths.GroupsManager'):
|
||||
data = self.storage.getPickle(username)
|
||||
data = self.storage.get_unpickle(username)
|
||||
if not data:
|
||||
return
|
||||
groupsManager.validate(data[1])
|
||||
|
||||
def get_real_name(self, username: str) -> str:
|
||||
data = self.storage.getPickle(username)
|
||||
data = self.storage.get_unpickle(username)
|
||||
if not data:
|
||||
return username
|
||||
return data[0]
|
||||
|
@ -38,7 +38,7 @@ import ldap.filter
|
||||
from django.utils.translation import gettext_noop as _
|
||||
|
||||
from uds.core import auths, types, consts, exceptions
|
||||
from uds.core.auths.auth import authenticate_log_login
|
||||
from uds.core.auths.auth import log_login
|
||||
from uds.core.ui import gui
|
||||
from uds.core.util import ldaputil
|
||||
|
||||
@ -316,7 +316,7 @@ class SimpleLDAPAuthenticator(auths.Authenticator):
|
||||
return 'mfa_' + str(self.db_obj().uuid) + username
|
||||
|
||||
def mfa_identifier(self, username: str) -> str:
|
||||
return self.storage.getPickle(self.mfaStorageKey(username)) or ''
|
||||
return self.storage.get_unpickle(self.mfaStorageKey(username)) or ''
|
||||
|
||||
def __connection(self):
|
||||
"""
|
||||
@ -447,14 +447,14 @@ class SimpleLDAPAuthenticator(auths.Authenticator):
|
||||
user = self.__getUser(username)
|
||||
|
||||
if user is None:
|
||||
authenticate_log_login(request, self.db_obj(), username, 'Invalid user')
|
||||
log_login(request, self.db_obj(), username, 'Invalid user')
|
||||
return types.auth.FAILED_AUTH
|
||||
|
||||
try:
|
||||
# Let's see first if it credentials are fine
|
||||
self.__connectAs(user['dn'], credentials) # Will raise an exception if it can't connect
|
||||
except Exception:
|
||||
authenticate_log_login(request, self.db_obj(), username, 'Invalid password')
|
||||
log_login(request, self.db_obj(), username, 'Invalid password')
|
||||
return types.auth.FAILED_AUTH
|
||||
|
||||
# store the user mfa attribute if it is set
|
||||
|
@ -513,7 +513,7 @@ def web_logout(
|
||||
return response
|
||||
|
||||
|
||||
def authenticate_log_login(
|
||||
def log_login(
|
||||
request: 'ExtendedHttpRequest',
|
||||
authenticator: models.Authenticator,
|
||||
userName: str,
|
||||
@ -558,7 +558,7 @@ def authenticate_log_login(
|
||||
logger.info('Root %s from %s where OS is %s', log_string, request.ip, request.os.os.name)
|
||||
|
||||
|
||||
def auth_log_logout(request: 'ExtendedHttpRequest') -> None:
|
||||
def log_logout(request: 'ExtendedHttpRequest') -> None:
|
||||
if request.user:
|
||||
if request.user.manager.id is not None:
|
||||
log.log(
|
||||
|
@ -39,8 +39,12 @@ HIDDEN: typing.Final[str] = 'h'
|
||||
DISABLED: typing.Final[str] = 'd'
|
||||
|
||||
# net_filter
|
||||
# Note: this are STANDARD values used on "default field" networks on RESP API
|
||||
# Named them for better reading, but cannot be changed, since they are used on RESP API
|
||||
# Note: this are STANDARD values used on "default field" networks on REST API
|
||||
# Named them for better reading, but cannot be changed, since they are used on REST API
|
||||
NO_FILTERING: typing.Final[str] = 'n'
|
||||
ALLOW: typing.Final[str] = 'a'
|
||||
DENY: typing.Final[str] = 'd'
|
||||
|
||||
# Cookie for mfa and csrf field
|
||||
MFA_COOKIE_NAME: typing.Final[str] = 'mfa_status'
|
||||
CSRF_FIELD: typing.Final[str] = 'csrfmiddlewaretoken'
|
||||
|
@ -315,3 +315,9 @@ class CryptoManager(metaclass=singleton.Singleton):
|
||||
return hashlib.sha3_256(
|
||||
(self.random_string(24, True) + datetime.datetime.now().strftime('%H%M%S%f')).encode()
|
||||
).hexdigest()
|
||||
|
||||
def sha(self, value: typing.Union[str, bytes]) -> str:
|
||||
if isinstance(value, str):
|
||||
value = value.encode()
|
||||
|
||||
return hashlib.sha3_256(value).hexdigest()
|
||||
|
@ -215,7 +215,7 @@ class MFA(Module):
|
||||
Internal method to get the data from storage
|
||||
"""
|
||||
storageKey = request.ip + userId
|
||||
return self.storage.getPickle(storageKey)
|
||||
return self.storage.get_unpickle(storageKey)
|
||||
|
||||
def _remove_data(self, request: 'ExtendedHttpRequest', userId: str) -> None:
|
||||
"""
|
||||
|
@ -443,7 +443,7 @@ class Service(Module):
|
||||
|
||||
def recover_id_info(self, id: str, delete: bool = False) -> typing.Any:
|
||||
# recovers the information
|
||||
value = self.storage.getPickle('__nfo_' + id)
|
||||
value = self.storage.get_unpickle('__nfo_' + id)
|
||||
if value and delete:
|
||||
self.storage.delete('__nfo_' + id)
|
||||
return value
|
||||
|
@ -65,28 +65,24 @@ class Cache:
|
||||
return serializer.deserialize(codecs.decode(value.encode(), 'base64'))
|
||||
|
||||
_serializer: typing.ClassVar[collections.abc.Callable[[typing.Any], str]] = _basic_serialize
|
||||
_deserializer: typing.ClassVar[
|
||||
collections.abc.Callable[[str], typing.Any]
|
||||
] = _basic_deserialize
|
||||
_deserializer: typing.ClassVar[collections.abc.Callable[[str], typing.Any]] = _basic_deserialize
|
||||
|
||||
def __init__(self, owner: typing.Union[str, bytes]):
|
||||
self._owner = typing.cast(str, owner.decode('utf-8') if isinstance(owner, bytes) else owner)
|
||||
self._bowner = self._owner.encode('utf8')
|
||||
|
||||
def __getKey(self, key: typing.Union[str, bytes]) -> str:
|
||||
def __get_key(self, key: typing.Union[str, bytes]) -> str:
|
||||
if isinstance(key, str):
|
||||
key = key.encode('utf8')
|
||||
return hash_key(self._bowner + key)
|
||||
|
||||
def get(
|
||||
self, skey: typing.Union[str, bytes], defValue: typing.Any = None
|
||||
) -> typing.Any:
|
||||
def get(self, skey: typing.Union[str, bytes], defValue: typing.Any = None) -> typing.Any:
|
||||
now = sql_datetime()
|
||||
# logger.debug('Requesting key "%s" for cache "%s"', skey, self._owner)
|
||||
try:
|
||||
key = self.__getKey(skey)
|
||||
key = self.__get_key(skey)
|
||||
# logger.debug('Key: %s', key)
|
||||
c: DBCache = DBCache.objects.get(pk=key) # @UndefinedVariable
|
||||
c: DBCache = DBCache.objects.get(owner=self._owner, pk=key)
|
||||
# If expired
|
||||
if now > c.created + datetime.timedelta(seconds=c.validity):
|
||||
return defValue
|
||||
@ -105,21 +101,22 @@ class Cache:
|
||||
Cache.misses += 1
|
||||
# logger.debug('key not found: %s', skey)
|
||||
return defValue
|
||||
#except OperationalError:
|
||||
# If database is not ready, just return default value
|
||||
# This is not a big issue, since cache is not critical
|
||||
# and probably will be generated by sqlite on high concurrency
|
||||
|
||||
# except OperationalError:
|
||||
# If database is not ready, just return default value
|
||||
# This is not a big issue, since cache is not critical
|
||||
# and probably will be generated by sqlite on high concurrency
|
||||
|
||||
# Cache.misses += 1
|
||||
# return defValue
|
||||
except Exception:
|
||||
import inspect
|
||||
|
||||
# Get caller
|
||||
error = 'Error Getting cache key from: '
|
||||
for caller in inspect.stack():
|
||||
error += f'{caller.filename}:{caller.lineno} -> '
|
||||
logger.error(error)
|
||||
|
||||
|
||||
# logger.exception('Error getting cache key: %s', skey)
|
||||
Cache.misses += 1
|
||||
return defValue
|
||||
@ -137,7 +134,7 @@ class Cache:
|
||||
"""
|
||||
# logger.debug('Removing key "%s" for uService "%s"' % (skey, self._owner))
|
||||
try:
|
||||
key = self.__getKey(skey)
|
||||
key = self.__get_key(skey)
|
||||
DBCache.objects.get(pk=key).delete() # @UndefinedVariable
|
||||
return True
|
||||
except DBCache.DoesNotExist: # @UndefinedVariable
|
||||
@ -163,22 +160,34 @@ class Cache:
|
||||
# logger.debug('Saving key "%s" for cache "%s"' % (skey, self._owner,))
|
||||
if validity is None:
|
||||
validity = consts.system.DEFAULT_CACHE_TIMEOUT
|
||||
key = self.__getKey(skey)
|
||||
key = self.__get_key(skey)
|
||||
strValue = Cache._serializer(value)
|
||||
now = sql_datetime()
|
||||
# Remove existing if any and create a new one
|
||||
with transaction.atomic():
|
||||
try:
|
||||
# Remove if existing
|
||||
DBCache.objects.filter(pk=key).delete()
|
||||
# And create a new one
|
||||
DBCache.objects.create(
|
||||
DBCache.objects.update_or_create(
|
||||
pk=key,
|
||||
owner=self._owner,
|
||||
key=key,
|
||||
value=strValue,
|
||||
created=now,
|
||||
validity=validity,
|
||||
) # @UndefinedVariable
|
||||
defaults={
|
||||
'owner': self._owner,
|
||||
'key': key,
|
||||
'value': strValue,
|
||||
'created': now,
|
||||
'validity': validity,
|
||||
},
|
||||
)
|
||||
|
||||
# # Remove if existing
|
||||
# DBCache.objects.filter(pk=key).delete()
|
||||
# # And create a new one
|
||||
# DBCache.objects.create(
|
||||
# owner=self._owner,
|
||||
# key=key,
|
||||
# value=strValue,
|
||||
# created=now,
|
||||
# validity=validity,
|
||||
# ) # @UndefinedVariable
|
||||
return # And return
|
||||
except Exception as e:
|
||||
logger.debug('Transaction in course, cannot store value: %s', e)
|
||||
@ -192,7 +201,7 @@ class Cache:
|
||||
def refresh(self, skey: typing.Union[str, bytes]) -> None:
|
||||
# logger.debug('Refreshing key "%s" for cache "%s"' % (skey, self._owner,))
|
||||
try:
|
||||
key = self.__getKey(skey)
|
||||
key = self.__get_key(skey)
|
||||
c = DBCache.objects.get(pk=key)
|
||||
c.created = sql_datetime()
|
||||
c.save()
|
||||
|
@ -46,7 +46,7 @@ logger = logging.getLogger(__name__)
|
||||
MARK = '_mgb_'
|
||||
|
||||
|
||||
def _calcKey(owner: bytes, key: bytes, extra: typing.Optional[bytes] = None) -> str:
|
||||
def _calculate_key(owner: bytes, key: bytes, extra: typing.Optional[bytes] = None) -> str:
|
||||
h = hashlib.md5(usedforsecurity=False)
|
||||
h.update(owner)
|
||||
h.update(key)
|
||||
@ -55,14 +55,14 @@ def _calcKey(owner: bytes, key: bytes, extra: typing.Optional[bytes] = None) ->
|
||||
return h.hexdigest()
|
||||
|
||||
|
||||
def _encodeValue(key: str, value: typing.Any, compat: bool = False) -> str:
|
||||
def _encode_value(key: str, value: typing.Any, compat: bool = False) -> str:
|
||||
if not compat:
|
||||
return base64.b64encode(pickle.dumps((MARK, key, value))).decode()
|
||||
# Compatibility save
|
||||
return base64.b64encode(pickle.dumps(value)).decode()
|
||||
|
||||
|
||||
def _decodeValue(dbk: str, value: typing.Optional[str]) -> tuple[str, typing.Any]:
|
||||
def _decode_value(dbk: str, value: typing.Optional[str]) -> tuple[str, typing.Any]:
|
||||
if value:
|
||||
try:
|
||||
v = pickle.loads(base64.b64decode(value.encode())) # nosec: This is e controled pickle loading
|
||||
@ -121,7 +121,7 @@ class StorageAsDict(MutableMapping):
|
||||
if key[0] == '#':
|
||||
# Compat with old db key
|
||||
return key[1:]
|
||||
return _calcKey(self._owner.encode(), key.encode())
|
||||
return _calculate_key(self._owner.encode(), key.encode())
|
||||
|
||||
def __getitem__(self, key: str) -> typing.Any:
|
||||
if not isinstance(key, str):
|
||||
@ -130,7 +130,11 @@ class StorageAsDict(MutableMapping):
|
||||
dbk = self._key(key)
|
||||
try:
|
||||
c: DBStorage = typing.cast(DBStorage, self._db.get(pk=dbk))
|
||||
return _decodeValue(dbk, c.data)[1] # Ignores original key
|
||||
if c.owner != self._owner: # Maybe a key collision,
|
||||
logger.error('Key collision detected for key %s', key)
|
||||
return None
|
||||
okey, value = _decode_value(dbk, c.data)
|
||||
return _decode_value(dbk, c.data)[1] # Ignores original key
|
||||
except DBStorage.DoesNotExist:
|
||||
return None
|
||||
|
||||
@ -139,7 +143,7 @@ class StorageAsDict(MutableMapping):
|
||||
raise TypeError(f'Key must be str type, {type(key)} found')
|
||||
|
||||
dbk = self._key(key)
|
||||
data = _encodeValue(key, value, self._compat)
|
||||
data = _encode_value(key, value, self._compat)
|
||||
# ignores return value, we don't care if it was created or updated
|
||||
DBStorage.objects.update_or_create(
|
||||
key=dbk, defaults={'data': data, 'attr1': self._group, 'owner': self._owner}
|
||||
@ -153,7 +157,7 @@ class StorageAsDict(MutableMapping):
|
||||
"""
|
||||
Iterates through keys
|
||||
"""
|
||||
return iter(_decodeValue(i.key, i.data)[0] for i in self._filtered)
|
||||
return iter(_decode_value(i.key, i.data)[0] for i in self._filtered)
|
||||
|
||||
def __contains__(self, key: object) -> bool:
|
||||
if isinstance(key, str):
|
||||
@ -165,10 +169,10 @@ class StorageAsDict(MutableMapping):
|
||||
|
||||
# Optimized methods, avoid re-reading from DB
|
||||
def items(self):
|
||||
return iter(_decodeValue(i.key, i.data) for i in self._filtered)
|
||||
return iter(_decode_value(i.key, i.data) for i in self._filtered)
|
||||
|
||||
def values(self):
|
||||
return iter(_decodeValue(i.key, i.data)[1] for i in self._filtered)
|
||||
return iter(_decode_value(i.key, i.data)[1] for i in self._filtered)
|
||||
|
||||
def get(self, key: str, default: typing.Any = None) -> typing.Any:
|
||||
return self[key] or default
|
||||
@ -191,6 +195,11 @@ class StorageAccess:
|
||||
Allows the access to the storage as a dict, with atomic transaction if requested
|
||||
"""
|
||||
|
||||
owner: str
|
||||
group: typing.Optional[str]
|
||||
atomic: bool
|
||||
compat: bool
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
owner: str,
|
||||
@ -223,13 +232,13 @@ class Storage:
|
||||
_bownwer: bytes
|
||||
|
||||
def __init__(self, owner: typing.Union[str, bytes]):
|
||||
self._owner = owner.decode('utf-8') if isinstance(owner, bytes) else owner
|
||||
self._owner = typing.cast(str, owner.decode('utf-8') if isinstance(owner, bytes) else owner)
|
||||
self._bowner = self._owner.encode('utf8')
|
||||
|
||||
def getKey(self, key: typing.Union[str, bytes]) -> str:
|
||||
return _calcKey(self._bowner, key.encode('utf8') if isinstance(key, str) else key)
|
||||
def get_key(self, key: typing.Union[str, bytes]) -> str:
|
||||
return _calculate_key(self._bowner, key.encode('utf8') if isinstance(key, str) else key)
|
||||
|
||||
def saveData(
|
||||
def save_to_db(
|
||||
self,
|
||||
skey: typing.Union[str, bytes],
|
||||
data: typing.Any,
|
||||
@ -240,21 +249,21 @@ class Storage:
|
||||
self.remove(skey)
|
||||
return
|
||||
|
||||
key = self.getKey(skey)
|
||||
key = self.get_key(skey)
|
||||
if isinstance(data, str):
|
||||
data = data.encode('utf-8')
|
||||
dataStr = codecs.encode(data, 'base64').decode()
|
||||
data_string = codecs.encode(data, 'base64').decode()
|
||||
attr1 = attr1 or ''
|
||||
try:
|
||||
DBStorage.objects.create(owner=self._owner, key=key, data=dataStr, attr1=attr1)
|
||||
DBStorage.objects.create(owner=self._owner, key=key, data=data_string, attr1=attr1)
|
||||
except Exception:
|
||||
with transaction.atomic():
|
||||
DBStorage.objects.filter(key=key).select_for_update().update(
|
||||
owner=self._owner, data=dataStr, attr1=attr1
|
||||
owner=self._owner, data=data_string, attr1=attr1
|
||||
) # @UndefinedVariable
|
||||
|
||||
def put(self, skey: typing.Union[str, bytes], data: typing.Any) -> None:
|
||||
return self.saveData(skey, data)
|
||||
return self.save_to_db(skey, data)
|
||||
|
||||
def put_pickle(
|
||||
self,
|
||||
@ -262,23 +271,23 @@ class Storage:
|
||||
data: typing.Any,
|
||||
attr1: typing.Optional[str] = None,
|
||||
) -> None:
|
||||
return self.saveData(
|
||||
return self.save_to_db(
|
||||
skey, pickle.dumps(data), attr1
|
||||
) # Protocol 2 is compatible with python 2.7. This will be unnecesary when fully migrated
|
||||
|
||||
def updateData(
|
||||
def update_to_db(
|
||||
self,
|
||||
skey: typing.Union[str, bytes],
|
||||
data: typing.Any,
|
||||
attr1: typing.Optional[str] = None,
|
||||
) -> None:
|
||||
self.saveData(skey, data, attr1)
|
||||
self.save_to_db(skey, data, attr1)
|
||||
|
||||
def readData(
|
||||
def read_from_db(
|
||||
self, skey: typing.Union[str, bytes], fromPickle: bool = False
|
||||
) -> typing.Optional[typing.Union[str, bytes]]:
|
||||
try:
|
||||
key = self.getKey(skey)
|
||||
key = self.get_key(skey)
|
||||
c: DBStorage = DBStorage.objects.get(pk=key) # @UndefinedVariable
|
||||
val = codecs.decode(c.data.encode(), 'base64')
|
||||
|
||||
@ -293,15 +302,15 @@ class Storage:
|
||||
return None
|
||||
|
||||
def get(self, skey: typing.Union[str, bytes]) -> typing.Optional[typing.Union[str, bytes]]:
|
||||
return self.readData(skey)
|
||||
return self.read_from_db(skey)
|
||||
|
||||
def getPickle(self, skey: typing.Union[str, bytes]) -> typing.Any:
|
||||
v = self.readData(skey, True)
|
||||
def get_unpickle(self, skey: typing.Union[str, bytes]) -> typing.Any:
|
||||
v = self.read_from_db(skey, True)
|
||||
if v:
|
||||
return pickle.loads(typing.cast(bytes, v)) # nosec: This is e controled pickle loading
|
||||
return None
|
||||
|
||||
def getPickleByAttr1(self, attr1: str, forUpdate: bool = False):
|
||||
def get_unpickle_by_attr1(self, attr1: str, forUpdate: bool = False):
|
||||
try:
|
||||
query = DBStorage.objects.filter(owner=self._owner, attr1=attr1)
|
||||
if forUpdate:
|
||||
@ -312,11 +321,16 @@ class Storage:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def remove(self, skey: typing.Union[collections.abc.Iterable[typing.Union[str, bytes]], str, bytes]) -> None:
|
||||
keys: collections.abc.Iterable[typing.Union[str, bytes]] = [skey] if isinstance(skey, (str, bytes)) else skey
|
||||
def remove(
|
||||
self, skey: typing.Union[collections.abc.Iterable[typing.Union[str, bytes]], str, bytes]
|
||||
) -> None:
|
||||
keys: collections.abc.Iterable[typing.Union[str, bytes]] = typing.cast(
|
||||
collections.abc.Iterable[typing.Union[str, bytes]],
|
||||
[skey] if isinstance(skey, (str, bytes)) else skey,
|
||||
)
|
||||
try:
|
||||
# Process several keys at once
|
||||
DBStorage.objects.filter(key__in=[self.getKey(k) for k in keys]).delete()
|
||||
DBStorage.objects.filter(key__in=[self.get_key(k) for k in keys]).delete()
|
||||
except Exception: # nosec: Not interested in processing exceptions, just ignores it
|
||||
pass
|
||||
|
||||
@ -342,7 +356,9 @@ class Storage:
|
||||
) -> StorageAccess:
|
||||
return StorageAccess(self._owner, group=group, atomic=atomic, compat=compat)
|
||||
|
||||
def locateByAttr1(self, attr1: typing.Union[collections.abc.Iterable[str], str]) -> collections.abc.Iterable[bytes]:
|
||||
def search_by_attr1(
|
||||
self, attr1: typing.Union[collections.abc.Iterable[str], str]
|
||||
) -> collections.abc.Iterable[bytes]:
|
||||
if isinstance(attr1, str):
|
||||
query = DBStorage.objects.filter(owner=self._owner, attr1=attr1) # @UndefinedVariable
|
||||
else:
|
||||
@ -365,7 +381,7 @@ class Storage:
|
||||
for v in query: # @UndefinedVariable
|
||||
yield (v.key, codecs.decode(v.data.encode(), 'base64'), v.attr1)
|
||||
|
||||
def filterPickle(
|
||||
def filter_unpickle(
|
||||
self, attr1: typing.Optional[str] = None, forUpdate: bool = False
|
||||
) -> collections.abc.Iterable[tuple[str, typing.Any, 'str|None']]:
|
||||
for v in self.filter(attr1, forUpdate):
|
||||
|
@ -128,7 +128,7 @@ class TOTP_MFA(mfas.MFA):
|
||||
# Get data from storage related to this user
|
||||
# Data contains the secret and if the user has already logged in already some time
|
||||
# so we show the QR code only once
|
||||
data: typing.Optional[tuple[str, bool]] = self.storage.getPickle(userId)
|
||||
data: typing.Optional[tuple[str, bool]] = self.storage.get_unpickle(userId)
|
||||
if data is None:
|
||||
data = (pyotp.random_base32(), False)
|
||||
self._saveUserData(userId, data)
|
||||
|
@ -36,7 +36,7 @@ import collections.abc
|
||||
|
||||
from django.db import models
|
||||
|
||||
from uds.core import transports, types
|
||||
from uds.core import transports, types, consts
|
||||
|
||||
from uds.core.util import net
|
||||
|
||||
@ -58,13 +58,8 @@ class Transport(ManagedObjectModel, TaggingMixin):
|
||||
Sample of transports are RDP, Spice, Web file uploader, etc...
|
||||
"""
|
||||
|
||||
# Constants for net_filter
|
||||
NO_FILTERING = 'n'
|
||||
ALLOW = 'a'
|
||||
DENY = 'd'
|
||||
|
||||
priority = models.IntegerField(default=0, db_index=True)
|
||||
net_filtering = models.CharField(max_length=1, default=NO_FILTERING, db_index=True)
|
||||
net_filtering = models.CharField(max_length=1, default=consts.auth.NO_FILTERING, db_index=True)
|
||||
# We store allowed oss as a comma-separated list
|
||||
allowed_oss = models.CharField(max_length=255, default='')
|
||||
# Label, to group transports on meta pools
|
||||
@ -130,14 +125,14 @@ class Transport(ManagedObjectModel, TaggingMixin):
|
||||
# Avoid circular import
|
||||
from uds.models import Network # pylint: disable=import-outside-toplevel
|
||||
|
||||
if self.net_filtering == Transport.NO_FILTERING:
|
||||
if self.net_filtering == consts.auth.NO_FILTERING:
|
||||
return True
|
||||
ip, version = net.ip_to_long(ipStr)
|
||||
# Allow
|
||||
exists = self.networks.filter(
|
||||
start__lte=Network.hexlify(ip), end__gte=Network.hexlify(ip), version=version
|
||||
).exists()
|
||||
if self.net_filtering == Transport.ALLOW:
|
||||
if self.net_filtering == consts.auth.ALLOW:
|
||||
return exists
|
||||
# Deny, must not be in any network
|
||||
return not exists
|
||||
|
@ -140,7 +140,7 @@ class TelegramNotifier(messaging.Notifier):
|
||||
telegramMsg = f'{group} - {identificator} - {str(level)}: {message}'
|
||||
logger.debug('Sending telegram message: %s', telegramMsg)
|
||||
# load chatIds
|
||||
chatIds = self.storage.getPickle('chatIds') or []
|
||||
chatIds = self.storage.get_unpickle('chatIds') or []
|
||||
t = telegram.Telegram(self.accessToken.value, self.botname.value)
|
||||
for chatId in chatIds:
|
||||
with ignoreExceptions():
|
||||
@ -151,7 +151,7 @@ class TelegramNotifier(messaging.Notifier):
|
||||
def subscribeUser(self, chatId: int) -> None:
|
||||
# we do not expect to have a lot of users, so we will use a simple storage
|
||||
# that holds a list of chatIds
|
||||
chatIds = self.storage.getPickle('chatIds') or []
|
||||
chatIds = self.storage.get_unpickle('chatIds') or []
|
||||
if chatId not in chatIds:
|
||||
chatIds.append(chatId)
|
||||
self.storage.put_pickle('chatIds', chatIds)
|
||||
@ -160,7 +160,7 @@ class TelegramNotifier(messaging.Notifier):
|
||||
def unsubscriteUser(self, chatId: int) -> None:
|
||||
# we do not expect to have a lot of users, so we will use a simple storage
|
||||
# that holds a list of chatIds
|
||||
chatIds = self.storage.getPickle('chatIds') or []
|
||||
chatIds = self.storage.get_unpickle('chatIds') or []
|
||||
if chatId in chatIds:
|
||||
chatIds.remove(chatId)
|
||||
self.storage.put_pickle('chatIds', chatIds)
|
||||
@ -170,7 +170,7 @@ class TelegramNotifier(messaging.Notifier):
|
||||
if not self.accessToken.value.strip():
|
||||
return # no access token, no messages
|
||||
# Time of last retrieve
|
||||
lastCheck: typing.Optional[datetime.datetime] = self.storage.getPickle('lastCheck')
|
||||
lastCheck: typing.Optional[datetime.datetime] = self.storage.get_unpickle('lastCheck')
|
||||
now = sql_datetime()
|
||||
|
||||
# If last check is not set, we will set it to now
|
||||
@ -185,7 +185,7 @@ class TelegramNotifier(messaging.Notifier):
|
||||
# Update last check
|
||||
self.storage.put_pickle('lastCheck', now)
|
||||
|
||||
lastOffset = self.storage.getPickle('lastOffset') or 0
|
||||
lastOffset = self.storage.get_unpickle('lastOffset') or 0
|
||||
t = telegram.Telegram(self.accessToken.value, last_offset=lastOffset)
|
||||
with ignoreExceptions(): # In case getUpdates fails, ignore it
|
||||
for update in t.get_updates():
|
||||
|
@ -71,10 +71,10 @@ class OVirtDeferredRemoval(jobs.Job):
|
||||
if state in ('up', 'powering_up', 'suspended'):
|
||||
providerInstance.stopMachine(vmId)
|
||||
elif state != 'unknown': # Machine exists, remove it later
|
||||
providerInstance.storage.saveData('tr' + vmId, vmId, attr1='tRm')
|
||||
providerInstance.storage.save_to_db('tr' + vmId, vmId, attr1='tRm')
|
||||
|
||||
except Exception as e:
|
||||
providerInstance.storage.saveData('tr' + vmId, vmId, attr1='tRm')
|
||||
providerInstance.storage.save_to_db('tr' + vmId, vmId, attr1='tRm')
|
||||
logger.info(
|
||||
'Machine %s could not be removed right now, queued for later: %s',
|
||||
vmId,
|
||||
|
@ -169,7 +169,7 @@ class IPMachinesService(IPServiceBase):
|
||||
'{}~{}'.format(str(ip).strip(), i) for i, ip in enumerate(values['ipList']) if str(ip).strip()
|
||||
] # Allow duplicates right now
|
||||
# Current stored data, if it exists
|
||||
d = self.storage.readData('ips')
|
||||
d = self.storage.read_from_db('ips')
|
||||
old_ips = pickle.loads(d) if d and isinstance(d, bytes) else [] # nosec: pickle is safe here
|
||||
# dissapeared ones
|
||||
dissapeared = set(IPServiceBase.getIp(i.split('~')[0]) for i in old_ips) - set(
|
||||
@ -202,7 +202,7 @@ class IPMachinesService(IPServiceBase):
|
||||
}
|
||||
|
||||
def marshal(self) -> bytes:
|
||||
self.storage.saveData('ips', pickle.dumps(self._ips))
|
||||
self.storage.save_to_db('ips', pickle.dumps(self._ips))
|
||||
return b'\0'.join(
|
||||
[
|
||||
b'v7',
|
||||
@ -217,7 +217,7 @@ class IPMachinesService(IPServiceBase):
|
||||
|
||||
def unmarshal(self, data: bytes) -> None:
|
||||
values: list[bytes] = data.split(b'\0')
|
||||
d = self.storage.readData('ips')
|
||||
d = self.storage.read_from_db('ips')
|
||||
if isinstance(d, bytes):
|
||||
self._ips = pickle.loads(d) # nosec: pickle is safe here
|
||||
elif isinstance(d, str): # "legacy" saved elements
|
||||
@ -272,7 +272,7 @@ class IPMachinesService(IPServiceBase):
|
||||
for ip in allIps:
|
||||
theIP = IPServiceBase.getIp(ip)
|
||||
theMAC = IPServiceBase.getMac(ip)
|
||||
locked = self.storage.getPickle(theIP)
|
||||
locked = self.storage.get_unpickle(theIP)
|
||||
if self.canBeUsed(locked, now):
|
||||
if (
|
||||
self._port > 0
|
||||
@ -320,7 +320,7 @@ class IPMachinesService(IPServiceBase):
|
||||
logger.exception("Exception at getUnassignedMachine")
|
||||
|
||||
def enumerate_assignables(self):
|
||||
return [(ip, ip.split('~')[0]) for ip in self._ips if self.storage.readData(ip) is None]
|
||||
return [(ip, ip.split('~')[0]) for ip in self._ips if self.storage.read_from_db(ip) is None]
|
||||
|
||||
def assign_from_assignables(
|
||||
self,
|
||||
@ -333,7 +333,7 @@ class IPMachinesService(IPServiceBase):
|
||||
theMAC = IPServiceBase.getMac(assignableId)
|
||||
|
||||
now = sql_stamp_seconds()
|
||||
locked = self.storage.getPickle(theIP)
|
||||
locked = self.storage.get_unpickle(theIP)
|
||||
if self.canBeUsed(locked, now):
|
||||
self.storage.put_pickle(theIP, now)
|
||||
if theMAC:
|
||||
@ -351,7 +351,7 @@ class IPMachinesService(IPServiceBase):
|
||||
# Locate the IP on the storage
|
||||
theIP = IPServiceBase.getIp(id)
|
||||
now = sql_stamp_seconds()
|
||||
locked: typing.Union[None, str, int] = self.storage.getPickle(theIP)
|
||||
locked: typing.Union[None, str, int] = self.storage.get_unpickle(theIP)
|
||||
if self.canBeUsed(locked, now):
|
||||
self.storage.put_pickle(theIP, str(now)) # Lock it
|
||||
|
||||
@ -362,7 +362,7 @@ class IPMachinesService(IPServiceBase):
|
||||
logger.debug('Processing logout for %s: %s', self, id)
|
||||
# Locate the IP on the storage
|
||||
theIP = IPServiceBase.getIp(id)
|
||||
locked: typing.Union[None, str, int] = self.storage.getPickle(theIP)
|
||||
locked: typing.Union[None, str, int] = self.storage.get_unpickle(theIP)
|
||||
# If locked is str, has been locked by processLogin so we can unlock it
|
||||
if isinstance(locked, str):
|
||||
self.unassignMachine(id)
|
||||
|
@ -88,7 +88,7 @@ class IPSingleMachineService(IPServiceBase):
|
||||
def getUnassignedMachine(self) -> typing.Optional[str]:
|
||||
ip: typing.Optional[str] = None
|
||||
try:
|
||||
counter = self.storage.getPickle('counter')
|
||||
counter = self.storage.get_unpickle('counter')
|
||||
counter = counter + 1 if counter is not None else 1
|
||||
self.storage.put_pickle('counter', counter)
|
||||
ip = '{}~{}'.format(self.ip.value, counter)
|
||||
|
@ -513,7 +513,7 @@ if sys.platform == 'win32':
|
||||
"""
|
||||
Check if the machine has gracely stopped (timed shutdown)
|
||||
"""
|
||||
shutdown_start = self.storage.getPickle('shutdown')
|
||||
shutdown_start = self.storage.get_unpickle('shutdown')
|
||||
logger.debug('Shutdown start: %s', shutdown_start)
|
||||
if shutdown_start < 0: # Was already stopped
|
||||
# Machine is already stop
|
||||
|
@ -72,7 +72,7 @@ class ProxmoxDeferredRemoval(jobs.Job):
|
||||
except client.ProxmoxNotFound:
|
||||
return # Machine does not exists
|
||||
except Exception as e:
|
||||
providerInstance.storage.saveData('tr' + str(vmId), str(vmId), attr1='tRm')
|
||||
providerInstance.storage.save_to_db('tr' + str(vmId), str(vmId), attr1='tRm')
|
||||
logger.info(
|
||||
'Machine %s could not be removed right now, queued for later: %s',
|
||||
vmId,
|
||||
|
@ -117,13 +117,13 @@ class SampleUserServiceOne(services.UserService):
|
||||
a new unique name, so we keep the first generated name cached and don't
|
||||
generate more names. (Generator are simple utility classes)
|
||||
"""
|
||||
name: str = typing.cast(str, self.storage.readData('name'))
|
||||
name: str = typing.cast(str, self.storage.read_from_db('name'))
|
||||
if name is None:
|
||||
name = self.name_generator().get(
|
||||
self.service().get_base_name() + '-' + self.service().getColour(), 3
|
||||
)
|
||||
# Store value for persistence
|
||||
self.storage.saveData('name', name)
|
||||
self.storage.save_to_db('name', name)
|
||||
|
||||
return name
|
||||
|
||||
@ -139,7 +139,7 @@ class SampleUserServiceOne(services.UserService):
|
||||
:note: This IP is the IP of the "consumed service", so the transport can
|
||||
access it.
|
||||
"""
|
||||
self.storage.saveData('ip', ip)
|
||||
self.storage.save_to_db('ip', ip)
|
||||
|
||||
def get_unique_id(self) -> str:
|
||||
"""
|
||||
@ -151,10 +151,10 @@ class SampleUserServiceOne(services.UserService):
|
||||
The get method of a mac generator takes one param, that is the mac range
|
||||
to use to get an unused mac.
|
||||
"""
|
||||
mac = typing.cast(str, self.storage.readData('mac'))
|
||||
mac = typing.cast(str, self.storage.read_from_db('mac'))
|
||||
if mac is None:
|
||||
mac = self.mac_generator().get('00:00:00:00:00:00-00:FF:FF:FF:FF:FF')
|
||||
self.storage.saveData('mac', mac)
|
||||
self.storage.save_to_db('mac', mac)
|
||||
return mac
|
||||
|
||||
def get_ip(self) -> str:
|
||||
@ -175,7 +175,7 @@ class SampleUserServiceOne(services.UserService):
|
||||
show the IP to the administrator, this method will get called
|
||||
|
||||
"""
|
||||
ip = typing.cast(str, self.storage.readData('ip'))
|
||||
ip = typing.cast(str, self.storage.read_from_db('ip'))
|
||||
if ip is None:
|
||||
ip = '192.168.0.34' # Sample IP for testing purposses only
|
||||
return ip
|
||||
@ -241,11 +241,11 @@ class SampleUserServiceOne(services.UserService):
|
||||
"""
|
||||
import random
|
||||
|
||||
self.storage.saveData('count', '0')
|
||||
self.storage.save_to_db('count', '0')
|
||||
|
||||
# random fail
|
||||
if random.randint(0, 9) == 9: # nosec: just testing values
|
||||
self.storage.saveData('error', 'Random error at deployForUser :-)')
|
||||
self.storage.save_to_db('error', 'Random error at deployForUser :-)')
|
||||
return State.ERROR
|
||||
|
||||
return State.RUNNING
|
||||
@ -274,7 +274,7 @@ class SampleUserServiceOne(services.UserService):
|
||||
import random
|
||||
|
||||
countStr: typing.Optional[str] = typing.cast(
|
||||
str, self.storage.readData('count')
|
||||
str, self.storage.read_from_db('count')
|
||||
)
|
||||
count: int = 0
|
||||
if countStr:
|
||||
@ -288,10 +288,10 @@ class SampleUserServiceOne(services.UserService):
|
||||
|
||||
# random fail
|
||||
if random.randint(0, 9) == 9: # nosec: just testing values
|
||||
self.storage.saveData('error', 'Random error at check_state :-)')
|
||||
self.storage.save_to_db('error', 'Random error at check_state :-)')
|
||||
return State.ERROR
|
||||
|
||||
self.storage.saveData('count', str(count))
|
||||
self.storage.save_to_db('count', str(count))
|
||||
return State.RUNNING
|
||||
|
||||
def finish(self) -> None:
|
||||
@ -322,7 +322,7 @@ class SampleUserServiceOne(services.UserService):
|
||||
The user provided is just an string, that is provided by actor.
|
||||
"""
|
||||
# We store the value at storage, but never get used, just an example
|
||||
self.storage.saveData('user', username)
|
||||
self.storage.save_to_db('user', username)
|
||||
|
||||
def user_logged_out(self, username) -> None:
|
||||
"""
|
||||
@ -349,7 +349,7 @@ class SampleUserServiceOne(services.UserService):
|
||||
for it, and it will be asked everytime it's needed to be shown to the
|
||||
user (when the administation asks for it).
|
||||
"""
|
||||
return typing.cast(str, self.storage.readData('error')) or 'No error'
|
||||
return typing.cast(str, self.storage.read_from_db('error')) or 'No error'
|
||||
|
||||
def destroy(self) -> str:
|
||||
"""
|
||||
|
@ -388,7 +388,7 @@ class SampleUserServiceTwo(services.UserService):
|
||||
The user provided is just an string, that is provided by actors.
|
||||
"""
|
||||
# We store the value at storage, but never get used, just an example
|
||||
self.storage.saveData('user', username)
|
||||
self.storage.save_to_db('user', username)
|
||||
|
||||
def user_logged_out(self, username) -> None:
|
||||
"""
|
||||
|
@ -99,7 +99,7 @@ class TestUserService(services.UserService):
|
||||
|
||||
def get_ip(self) -> str:
|
||||
logger.info('Getting ip of deployment %s', self.data)
|
||||
ip = typing.cast(str, self.storage.readData('ip'))
|
||||
ip = typing.cast(str, self.storage.read_from_db('ip'))
|
||||
if ip is None:
|
||||
ip = '8.6.4.2' # Sample IP for testing purposses only
|
||||
return ip
|
||||
|
@ -183,7 +183,7 @@ urlpatterns = [
|
||||
# Services list, ...
|
||||
path(
|
||||
r'uds/webapi/services',
|
||||
uds.web.views.main.servicesData,
|
||||
uds.web.views.main.services_data_json,
|
||||
name='webapi.services',
|
||||
),
|
||||
# Transport own link processor
|
||||
|
@ -35,7 +35,7 @@ from django.http import HttpResponseRedirect
|
||||
|
||||
from django.utils.translation import gettext as _
|
||||
|
||||
from uds.core.auths.auth import authenticate, authenticate_log_login
|
||||
from uds.core.auths.auth import authenticate, log_login
|
||||
from uds.models import Authenticator, User
|
||||
from uds.core.util.config import GlobalConfig
|
||||
from uds.core.util.cache import Cache
|
||||
@ -114,13 +114,13 @@ def check_login( # pylint: disable=too-many-branches, too-many-statements
|
||||
and (tries >= maxTries)
|
||||
or triesByIp >= maxTries
|
||||
):
|
||||
authenticate_log_login(request, authenticator, userName, 'Temporarily blocked')
|
||||
log_login(request, authenticator, userName, 'Temporarily blocked')
|
||||
return LoginResult(
|
||||
errstr=_('Too many authentication errrors. User temporarily blocked')
|
||||
)
|
||||
# check if authenticator is visible for this requests
|
||||
if authInstance.is_ip_allowed(request=request) is False:
|
||||
authenticate_log_login(
|
||||
log_login(
|
||||
request,
|
||||
authenticator,
|
||||
userName,
|
||||
@ -136,7 +136,7 @@ def check_login( # pylint: disable=too-many-branches, too-many-statements
|
||||
logger.debug("Invalid user %s (access denied)", userName)
|
||||
cache.put(cacheKey, tries + 1, GlobalConfig.LOGIN_BLOCK.getInt())
|
||||
cache.put(request.ip, triesByIp + 1, GlobalConfig.LOGIN_BLOCK.getInt())
|
||||
authenticate_log_login(
|
||||
log_login(
|
||||
request,
|
||||
authenticator,
|
||||
userName,
|
||||
@ -154,7 +154,7 @@ def check_login( # pylint: disable=too-many-branches, too-many-statements
|
||||
if form.cleaned_data['logouturl'] != '':
|
||||
logger.debug('The logoout url will be %s', form.cleaned_data['logouturl'])
|
||||
request.session['logouturl'] = form.cleaned_data['logouturl']
|
||||
authenticate_log_login(request, authenticator, authResult.user.name)
|
||||
log_login(request, authenticator, authResult.user.name)
|
||||
return LoginResult(user=authResult.user, password=form.cleaned_data['password'])
|
||||
|
||||
logger.info('Invalid form received')
|
||||
|
@ -45,7 +45,7 @@ from uds.core.auths.auth import (
|
||||
web_login,
|
||||
web_logout,
|
||||
authenticate_via_callback,
|
||||
authenticate_log_login,
|
||||
log_login,
|
||||
getUDSCookie,
|
||||
)
|
||||
from uds.core.managers.user_service import UserServiceManager
|
||||
@ -111,14 +111,14 @@ def auth_callback_stage2(request: 'ExtendedHttpRequestWithUser', ticketId: str)
|
||||
raise exceptions.auth.Redirect(result.url)
|
||||
|
||||
if result.user is None:
|
||||
authenticate_log_login(request, authenticator, f'{params}', 'Invalid at auth callback')
|
||||
log_login(request, authenticator, f'{params}', 'Invalid at auth callback')
|
||||
raise exceptions.auth.InvalidUserException()
|
||||
|
||||
response = HttpResponseRedirect(reverse('page.index'))
|
||||
|
||||
web_login(request, response, result.user, '') # Password is unavailable in this case
|
||||
|
||||
authenticate_log_login(request, authenticator, result.user.name, 'Federated login')
|
||||
log_login(request, authenticator, result.user.name, 'Federated login')
|
||||
|
||||
# If MFA is provided, we need to redirect to MFA page
|
||||
request.authorized = True
|
||||
@ -239,7 +239,7 @@ def ticket_auth(
|
||||
web_login(request, None, usr, password)
|
||||
|
||||
# Log the login
|
||||
authenticate_log_login(request, auth, username, 'Ticket authentication')
|
||||
log_login(request, auth, username, 'Ticket authentication')
|
||||
|
||||
request.user = (
|
||||
usr # Temporarily store this user as "authenticated" user, next requests will be done using session
|
||||
@ -256,7 +256,9 @@ def ticket_auth(
|
||||
# Check if servicePool is part of the ticket
|
||||
if poolUuid:
|
||||
# Request service, with transport = None so it is automatic
|
||||
res = UserServiceManager().get_user_service_info(request.user, request.os, request.ip, poolUuid, None, False)
|
||||
res = UserServiceManager().get_user_service_info(
|
||||
request.user, request.os, request.ip, poolUuid, None, False
|
||||
)
|
||||
_, userService, _, transport, _ = res
|
||||
|
||||
transportInstance = transport.get_instance()
|
||||
|
@ -28,46 +28,39 @@
|
||||
"""
|
||||
Author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
"""
|
||||
import collections.abc
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
import logging
|
||||
import typing
|
||||
import collections.abc
|
||||
import random
|
||||
import json
|
||||
|
||||
from django.http import HttpRequest, HttpResponse, HttpResponseRedirect, JsonResponse
|
||||
from django.middleware import csrf
|
||||
from django.shortcuts import render
|
||||
from django.views.decorators.csrf import csrf_exempt
|
||||
from django.http import HttpRequest, HttpResponse, JsonResponse, HttpResponseRedirect
|
||||
from django.views.decorators.cache import never_cache
|
||||
from django.urls import reverse
|
||||
from django.utils.translation import gettext as _
|
||||
from py import log
|
||||
from uds.core.types.request import ExtendedHttpRequest
|
||||
from django.views.decorators.cache import never_cache
|
||||
from django.views.decorators.csrf import csrf_exempt
|
||||
|
||||
from uds.core.types.request import ExtendedHttpRequestWithUser
|
||||
from uds import auths, models
|
||||
from uds.core import exceptions, mfas, types, consts
|
||||
from uds.core.auths import auth
|
||||
from uds.core.util import state
|
||||
from uds.core.util.config import GlobalConfig
|
||||
from uds.core.managers.crypto import CryptoManager
|
||||
from uds.core.managers.user_service import UserServiceManager
|
||||
from uds.web.util import errors
|
||||
from uds.core.types.request import ExtendedHttpRequest, ExtendedHttpRequestWithUser
|
||||
from uds.core.util import config, state, storage
|
||||
from uds.core.util.model import sql_stamp_seconds
|
||||
from uds.web.forms.LoginForm import LoginForm
|
||||
from uds.web.forms.MFAForm import MFAForm
|
||||
from uds.web.util import configjs, errors
|
||||
from uds.web.util.authentication import check_login
|
||||
from uds.web.util.services import getServicesData
|
||||
from uds.web.util import configjs
|
||||
from uds.core import mfas, types, exceptions
|
||||
from uds import auths, models
|
||||
from uds.core.util.model import sql_stamp_seconds
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CSRF_FIELD = 'csrfmiddlewaretoken'
|
||||
MFA_COOKIE_NAME = 'mfa_status'
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
@ -82,7 +75,7 @@ def index(request: HttpRequest) -> HttpResponse:
|
||||
response = render(
|
||||
request=request,
|
||||
template_name='uds/modern/index.html',
|
||||
context={'csrf_field': CSRF_FIELD, 'csrf_token': csrf_token},
|
||||
context={'csrf_field': consts.auth.CSRF_FIELD, 'csrf_token': csrf_token},
|
||||
)
|
||||
|
||||
# Ensure UDS cookie is present
|
||||
@ -153,7 +146,7 @@ def login(request: ExtendedHttpRequest, tag: typing.Optional[str] = None) -> Htt
|
||||
@never_cache
|
||||
@auth.web_login_required(admin=False)
|
||||
def logout(request: ExtendedHttpRequestWithUser) -> HttpResponse:
|
||||
auth.auth_log_logout(request)
|
||||
auth.log_logout(request)
|
||||
request.session['restricted'] = False # Remove restricted
|
||||
request.authorized = False
|
||||
logoutResponse = request.user.logout(request)
|
||||
@ -169,11 +162,11 @@ def js(request: ExtendedHttpRequest) -> HttpResponse:
|
||||
|
||||
@never_cache
|
||||
@auth.deny_non_authenticated # web_login_required not used here because this is not a web page, but js
|
||||
def servicesData(request: ExtendedHttpRequestWithUser) -> HttpResponse:
|
||||
def services_data_json(request: ExtendedHttpRequestWithUser) -> HttpResponse:
|
||||
return JsonResponse(getServicesData(request))
|
||||
|
||||
|
||||
# The MFA page does not needs CRF token, so we disable it
|
||||
# The MFA page does not needs CSRF token, so we disable it
|
||||
@csrf_exempt
|
||||
def mfa(
|
||||
request: ExtendedHttpRequest,
|
||||
@ -182,26 +175,37 @@ def mfa(
|
||||
logger.warning('MFA: No user or user is already authorized')
|
||||
return HttpResponseRedirect(reverse('page.index')) # No user, no MFA
|
||||
|
||||
mfaProvider = typing.cast('None|models.MFA', request.user.manager.mfa)
|
||||
if not mfaProvider:
|
||||
store: 'storage.Storage' = storage.Storage('mfs')
|
||||
|
||||
mfa_provider = typing.cast('None|models.MFA', request.user.manager.mfa)
|
||||
if not mfa_provider:
|
||||
logger.warning('MFA: No MFA provider for user')
|
||||
return HttpResponseRedirect(reverse('page.index'))
|
||||
|
||||
mfaUserId = mfas.MFA.get_user_id(request.user)
|
||||
mfa_user_id = mfas.MFA.get_user_id(request.user)
|
||||
|
||||
# Try to get cookie anc check it
|
||||
mfaCookie = request.COOKIES.get(MFA_COOKIE_NAME, None)
|
||||
if mfaCookie == mfaUserId: # Cookie is valid, skip MFA setting authorization
|
||||
logger.debug('MFA: Cookie is valid, skipping MFA')
|
||||
request.authorized = True
|
||||
return HttpResponseRedirect(reverse('page.index'))
|
||||
mfa_cookie = request.COOKIES.get(consts.auth.MFA_COOKIE_NAME, None)
|
||||
if mfa_cookie and mfa_provider.remember_device > 0:
|
||||
stored_user_id: typing.Optional[str]
|
||||
created: typing.Optional[datetime.datetime]
|
||||
stored_user_id, created = store.get_unpickle(mfa_cookie) or (None, None)
|
||||
if (
|
||||
stored_user_id
|
||||
and created
|
||||
and created + datetime.timedelta(hours=mfa_provider.remember_device) > datetime.datetime.now()
|
||||
):
|
||||
# Cookie is valid, skip MFA setting authorization
|
||||
logger.debug('MFA: Cookie is valid, skipping MFA')
|
||||
request.authorized = True
|
||||
return HttpResponseRedirect(reverse('page.index'))
|
||||
|
||||
# Obtain MFA data
|
||||
authInstance = request.user.manager.get_instance()
|
||||
mfaInstance = typing.cast('mfas.MFA', mfaProvider.get_instance())
|
||||
auth_instance = request.user.manager.get_instance()
|
||||
mfa_instance = typing.cast('mfas.MFA', mfa_provider.get_instance())
|
||||
|
||||
# Get validity duration
|
||||
validity = mfaProvider.validity * 60
|
||||
validity = mfa_provider.validity * 60
|
||||
now = sql_stamp_seconds()
|
||||
start_time = request.session.get('mfa_start_time', now)
|
||||
|
||||
@ -211,26 +215,26 @@ def mfa(
|
||||
request.session.flush() # Clear session, and redirect to login
|
||||
return HttpResponseRedirect(reverse('page.login'))
|
||||
|
||||
mfaIdentifier = authInstance.mfa_identifier(request.user.name)
|
||||
label = mfaInstance.label()
|
||||
mfa_identifier = auth_instance.mfa_identifier(request.user.name)
|
||||
label = mfa_instance.label()
|
||||
|
||||
if not mfaIdentifier:
|
||||
emtpyIdentifiedAllowed = mfaInstance.allow_login_without_identifier(request)
|
||||
if not mfa_identifier:
|
||||
allow_login_without_identifier = mfa_instance.allow_login_without_identifier(request)
|
||||
# can be True, False or None
|
||||
if emtpyIdentifiedAllowed is True:
|
||||
if allow_login_without_identifier is True:
|
||||
# Allow login
|
||||
request.authorized = True
|
||||
return HttpResponseRedirect(reverse('page.index'))
|
||||
if emtpyIdentifiedAllowed is False:
|
||||
if allow_login_without_identifier is False:
|
||||
# Not allowed to login, redirect to login error page
|
||||
logger.warning(
|
||||
'MFA identifier not found for user %s on authenticator %s. It is required by MFA %s',
|
||||
request.user.name,
|
||||
request.user.manager.name,
|
||||
mfaProvider.name,
|
||||
mfa_provider.name,
|
||||
)
|
||||
return errors.errorView(request, errors.ACCESS_DENIED)
|
||||
# None, the authenticator will decide what to do if mfaIdentifier is empty
|
||||
# None, the authenticator will decide what to do if mfa_identifier is empty
|
||||
|
||||
tries = request.session.get('mfa_tries', 0)
|
||||
if request.method == 'POST': # User has provided MFA code
|
||||
@ -238,11 +242,11 @@ def mfa(
|
||||
if form.is_valid():
|
||||
code = form.cleaned_data['code']
|
||||
try:
|
||||
mfaInstance.validate(
|
||||
mfa_instance.validate(
|
||||
request,
|
||||
mfaUserId,
|
||||
mfa_user_id,
|
||||
request.user.name,
|
||||
mfaIdentifier,
|
||||
mfa_identifier,
|
||||
code,
|
||||
validity=validity,
|
||||
) # Will raise MFAError if code is not valid
|
||||
@ -256,11 +260,17 @@ def mfa(
|
||||
response = HttpResponseRedirect(reverse('page.index'))
|
||||
|
||||
# If mfaProvider requests to keep MFA code on client, create a mfacookie for this user
|
||||
if mfaProvider.remember_device > 0 and form.cleaned_data['remember'] is True:
|
||||
if mfa_provider.remember_device > 0 and form.cleaned_data['remember'] is True:
|
||||
# Store also cookie locally, to check if remember_device is changed
|
||||
mfa_cookie = CryptoManager().random_string(96)
|
||||
store.put_pickle(
|
||||
mfa_cookie,
|
||||
(mfa_user_id, now),
|
||||
)
|
||||
response.set_cookie(
|
||||
MFA_COOKIE_NAME,
|
||||
mfaUserId,
|
||||
max_age=mfaProvider.remember_device * 60 * 60,
|
||||
consts.auth.MFA_COOKIE_NAME,
|
||||
mfa_cookie,
|
||||
max_age=mfa_provider.remember_device * 60 * 60,
|
||||
)
|
||||
|
||||
return response
|
||||
@ -268,7 +278,7 @@ def mfa(
|
||||
logger.error('MFA error: %s', e)
|
||||
tries += 1
|
||||
request.session['mfa_tries'] = tries
|
||||
if tries >= GlobalConfig.MAX_LOGIN_TRIES.getInt():
|
||||
if tries >= config.GlobalConfig.MAX_LOGIN_TRIES.getInt():
|
||||
# Clean session
|
||||
request.session.flush()
|
||||
# Too many tries, redirect to login error page
|
||||
@ -280,11 +290,11 @@ def mfa(
|
||||
# Make MFA send a code
|
||||
request.session['mfa_tries'] = 0 # Reset tries
|
||||
try:
|
||||
result = mfaInstance.process(
|
||||
result = mfa_instance.process(
|
||||
request,
|
||||
mfaUserId,
|
||||
mfa_user_id,
|
||||
request.user.name,
|
||||
mfaIdentifier,
|
||||
mfa_identifier,
|
||||
validity=validity,
|
||||
)
|
||||
if result == mfas.MFA.RESULT.ALLOWED:
|
||||
@ -302,15 +312,15 @@ def mfa(
|
||||
# Compose a nice "XX years, XX months, XX days, XX hours, XX minutes" string from mfaProvider.remember_device
|
||||
remember_device = ''
|
||||
# Remember_device is in hours
|
||||
if mfaProvider.remember_device > 0:
|
||||
if mfa_provider.remember_device > 0:
|
||||
# if more than a day, we show days only
|
||||
if mfaProvider.remember_device >= 24:
|
||||
remember_device = _('{} days').format(mfaProvider.remember_device // 24)
|
||||
if mfa_provider.remember_device >= 24:
|
||||
remember_device = _('{} days').format(mfa_provider.remember_device // 24)
|
||||
else:
|
||||
remember_device = _('{} hours').format(mfaProvider.remember_device)
|
||||
remember_device = _('{} hours').format(mfa_provider.remember_device)
|
||||
|
||||
# Html from MFA provider
|
||||
mfaHtml = mfaInstance.html(request, mfaUserId, request.user.name)
|
||||
mfaHtml = mfa_instance.html(request, mfa_user_id, request.user.name)
|
||||
|
||||
# Redirect to index, but with MFA data
|
||||
request.session['mfa'] = {
|
||||
|
@ -43,24 +43,24 @@ class StorageTest(UDSTestCase):
|
||||
storage = Storage(UNICODE_CHARS)
|
||||
|
||||
storage.put(UNICODE_CHARS, b'chars')
|
||||
storage.saveData('saveData', UNICODE_CHARS, UNICODE_CHARS)
|
||||
storage.saveData('saveData2', UNICODE_CHARS_2, UNICODE_CHARS)
|
||||
storage.saveData('saveData3', UNICODE_CHARS, 'attribute')
|
||||
storage.saveData('saveData4', UNICODE_CHARS_2, 'attribute')
|
||||
storage.save_to_db('saveData', UNICODE_CHARS, UNICODE_CHARS)
|
||||
storage.save_to_db('saveData2', UNICODE_CHARS_2, UNICODE_CHARS)
|
||||
storage.save_to_db('saveData3', UNICODE_CHARS, 'attribute')
|
||||
storage.save_to_db('saveData4', UNICODE_CHARS_2, 'attribute')
|
||||
storage.put(b'key', UNICODE_CHARS)
|
||||
storage.put(UNICODE_CHARS_2, UNICODE_CHARS)
|
||||
|
||||
storage.put_pickle('pickle', VALUE_1)
|
||||
|
||||
self.assertEqual(storage.get(UNICODE_CHARS), u'chars') # Always returns unicod
|
||||
self.assertEqual(storage.readData('saveData'), UNICODE_CHARS)
|
||||
self.assertEqual(storage.readData('saveData2'), UNICODE_CHARS_2)
|
||||
self.assertEqual(storage.read_from_db('saveData'), UNICODE_CHARS)
|
||||
self.assertEqual(storage.read_from_db('saveData2'), UNICODE_CHARS_2)
|
||||
self.assertEqual(storage.get(b'key'), UNICODE_CHARS)
|
||||
self.assertEqual(storage.get(UNICODE_CHARS_2), UNICODE_CHARS)
|
||||
self.assertEqual(storage.getPickle('pickle'), VALUE_1)
|
||||
self.assertEqual(storage.get_unpickle('pickle'), VALUE_1)
|
||||
|
||||
self.assertEqual(len(list(storage.locateByAttr1(UNICODE_CHARS))), 2)
|
||||
self.assertEqual(len(list(storage.locateByAttr1('attribute'))), 2)
|
||||
self.assertEqual(len(list(storage.search_by_attr1(UNICODE_CHARS))), 2)
|
||||
self.assertEqual(len(list(storage.search_by_attr1('attribute'))), 2)
|
||||
|
||||
storage.remove(UNICODE_CHARS)
|
||||
storage.remove(b'key')
|
||||
@ -68,4 +68,4 @@ class StorageTest(UDSTestCase):
|
||||
|
||||
self.assertIsNone(storage.get(UNICODE_CHARS))
|
||||
self.assertIsNone(storage.get(b'key'))
|
||||
self.assertIsNone(storage.getPickle('pickle'))
|
||||
self.assertIsNone(storage.get_unpickle('pickle'))
|
||||
|
Loading…
Reference in New Issue
Block a user