mirror of
https://github.com/dkmstr/openuds.git
synced 2024-12-22 13:34:04 +03:00
Added ticket methods to Servers Rest Api
This commit is contained in:
parent
0eef9c2f09
commit
6e40e56d24
@ -33,6 +33,7 @@ import logging
|
||||
|
||||
from unittest import mock
|
||||
|
||||
from uds import models
|
||||
from uds.core.util import log
|
||||
|
||||
from ...utils import rest, random_ip_v4, random_ip_v6, random_mac
|
||||
@ -40,7 +41,6 @@ from ...fixtures import servers as servers_fixtures
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ...utils.test import UDSHttpResponse
|
||||
from uds import models
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -94,7 +94,21 @@ class ServerEventsLoginLogoutTest(rest.test.RESTTestCase):
|
||||
self.assertEqual(session.session_id, result['session_id'])
|
||||
self.assertEqual(self.user_service_managed.properties.get('last_username', ''), 'local_user_name')
|
||||
|
||||
# TODO: Finish this test
|
||||
def test_login_with_ticket(self) -> None:
|
||||
ticket_uuid = models.TicketStore.create({'user_service': self.user_service_managed.uuid, 'some_value': 'value'})
|
||||
response = self.client.rest_post(
|
||||
'/servers/event',
|
||||
data={
|
||||
'token': self.server.token,
|
||||
'type': 'login',
|
||||
'user_service': self.user_service_managed.uuid,
|
||||
'username': 'local_user_name',
|
||||
'ticket': ticket_uuid,
|
||||
},
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
data = response.json()['result']
|
||||
self.assertEqual(data['ticket']['some_value'], 'value')
|
||||
|
||||
def test_login_fail(self) -> None:
|
||||
response = self.client.rest_post(
|
||||
@ -191,3 +205,17 @@ class ServerEventsLoginLogoutTest(rest.test.RESTTestCase):
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.user_service_managed.refresh_from_db()
|
||||
self.assertEqual(self.user_service_managed.in_use, False)
|
||||
|
||||
def test_ticket(self) -> None:
|
||||
ticket_uuid = models.TicketStore.create({'user_service': self.user_service_managed.uuid, 'some_value': 'value'})
|
||||
response = self.client.rest_post(
|
||||
'/servers/event',
|
||||
data={
|
||||
'token': self.server.token,
|
||||
'type': 'ticket',
|
||||
'ticket': ticket_uuid,
|
||||
},
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
data = response.json()['result']
|
||||
self.assertEqual(data['some_value'], 'value')
|
@ -70,7 +70,8 @@ VALID_PARAMS = (
|
||||
class Tickets(Handler):
|
||||
"""
|
||||
Processes tickets access requests.
|
||||
Tickets are element used to "register" & "allow access" to users.
|
||||
Tickets are element used to "register" & "allow access" to users to a service.
|
||||
Designed to be used by external systems (like web services) to allow access to users to services.
|
||||
|
||||
The rest API accepts the following parameters:
|
||||
authId: uuid of the authenticator for the user | Mutually excluyents
|
||||
|
@ -383,4 +383,4 @@ class ServerManager(metaclass=singleton.Singleton):
|
||||
That is, this is not invoked directly unless a REST request is received from
|
||||
a server.
|
||||
"""
|
||||
return events.process(data)
|
||||
return events.process(server, data)
|
||||
|
@ -40,21 +40,42 @@ from uds.REST.utils import rest_result
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def process_log(data: typing.Dict[str, typing.Any]) -> typing.Any:
|
||||
def process_log(server: 'models.Server', data: typing.Dict[str, typing.Any]) -> typing.Any:
|
||||
if 'user_service' in data: # Log for an user service
|
||||
userService = models.UserService.objects.get(uuid=data['user_service'])
|
||||
log.doLog(
|
||||
userService, log.LogLevel.fromStr(data['level']), data['message'], source=log.LogSource.SERVER
|
||||
)
|
||||
else:
|
||||
server = models.Server.objects.get(token=data['token'])
|
||||
log.doLog(server, log.LogLevel.fromStr(data['level']), data['message'], source=log.LogSource.SERVER)
|
||||
|
||||
return rest_result(consts.OK)
|
||||
|
||||
|
||||
def process_login(data: typing.Dict[str, typing.Any]) -> typing.Any:
|
||||
server = models.Server.objects.get(token=data['token'])
|
||||
def process_login(server: 'models.Server', data: typing.Dict[str, typing.Any]) -> typing.Any:
|
||||
"""Processes the REST login event from a server
|
||||
|
||||
data: {
|
||||
'user_service': 'uuid of user service',
|
||||
'username': 'username',
|
||||
'session_id': 'session id',
|
||||
'ticket': 'ticket if any' # optional
|
||||
}
|
||||
|
||||
Returns a dict with the following keys:
|
||||
|
||||
{
|
||||
'ip': 'ip of connection origin',
|
||||
'hostname': 'hostname of connection origin',
|
||||
'dead_line': 'dead line of service', # The point in time when the service will be automatically removed, optional (None if not set)
|
||||
'max_idle': 'max idle time of service', # The max time the service can be idle before being removed, optional (None if not set)
|
||||
'session_id': 'session id', # The session id assigned to this login
|
||||
'ticket': 'ticket if any' # optional
|
||||
|
||||
}
|
||||
|
||||
|
||||
"""
|
||||
userService = models.UserService.objects.get(uuid=data['user_service'])
|
||||
server.setActorVersion(userService)
|
||||
|
||||
@ -73,19 +94,30 @@ def process_login(data: typing.Dict[str, typing.Any]) -> typing.Any:
|
||||
deadLine = deadLine = (
|
||||
userService.deployed_service.getDeadline() if not osManager or osManager.ignoreDeadLine() else None
|
||||
)
|
||||
result = {
|
||||
'ip': src.ip,
|
||||
'hostname': src.hostname,
|
||||
'dead_line': deadLine,
|
||||
'max_idle': maxIdle,
|
||||
'session_id': session_id,
|
||||
}
|
||||
|
||||
return rest_result(
|
||||
{
|
||||
'ip': src.ip,
|
||||
'hostname': src.hostname,
|
||||
'dead_line': deadLine,
|
||||
'max_idle': maxIdle,
|
||||
'session_id': session_id,
|
||||
}
|
||||
)
|
||||
if 'ticket' in data:
|
||||
result['ticket'] = models.TicketStore.get(data['ticket'], invalidate=True)
|
||||
|
||||
return rest_result(result)
|
||||
|
||||
|
||||
def process_logout(data: typing.Dict[str, typing.Any]) -> typing.Any:
|
||||
def process_logout(server: 'models.Server', data: typing.Dict[str, typing.Any]) -> typing.Any:
|
||||
"""Processes the REST logout event from a server
|
||||
|
||||
data: {
|
||||
'user_service': 'uuid of user service',
|
||||
'session_id': 'session id',
|
||||
}
|
||||
|
||||
Returns 'OK' if all went ok ({'result': 'OK', 'stamp': 'stamp'}), or an error if not ({'result': 'error', 'error': 'error description'}})
|
||||
"""
|
||||
userService = models.UserService.objects.get(uuid=data['user_service'])
|
||||
|
||||
session_id = data['session_id']
|
||||
@ -101,8 +133,7 @@ def process_logout(data: typing.Dict[str, typing.Any]) -> typing.Any:
|
||||
return rest_result(consts.OK)
|
||||
|
||||
|
||||
def process_ping(data: typing.Dict[str, typing.Any]) -> typing.Any:
|
||||
server = models.Server.objects.get(token=data['token'])
|
||||
def process_ping(server: 'models.Server', data: typing.Dict[str, typing.Any]) -> typing.Any:
|
||||
if 'stats' in data:
|
||||
server.stats = types.servers.ServerStatsType.fromDict(data['stats'])
|
||||
# Set stats on server
|
||||
@ -111,21 +142,29 @@ def process_ping(data: typing.Dict[str, typing.Any]) -> typing.Any:
|
||||
return rest_result(consts.OK)
|
||||
|
||||
|
||||
PROCESSORS: typing.Final[typing.Mapping[str, typing.Callable[[typing.Dict[str, typing.Any]], typing.Any]]] = {
|
||||
def process_ticket(server: 'models.Server', data: typing.Dict[str, typing.Any]) -> typing.Any:
|
||||
return rest_result(models.TicketStore.get(data['ticket'], invalidate=True))
|
||||
|
||||
|
||||
PROCESSORS: typing.Final[
|
||||
typing.Mapping[str, typing.Callable[['models.Server', typing.Dict[str, typing.Any]], typing.Any]]
|
||||
] = {
|
||||
'log': process_log,
|
||||
'login': process_login,
|
||||
'logout': process_logout,
|
||||
'ping': process_ping,
|
||||
'ticket': process_ticket,
|
||||
}
|
||||
|
||||
|
||||
def process(data: typing.Dict[str, typing.Any]) -> typing.Any:
|
||||
def process(server: 'models.Server', data: typing.Dict[str, typing.Any]) -> typing.Any:
|
||||
"""Processes the event data
|
||||
Valid events are (in key 'type'):
|
||||
* log: A log message (to server or userService)
|
||||
* login: A login has been made (to an userService)
|
||||
* logout: A logout has been made (to an userService)
|
||||
* ping: A ping request (can include stats, etc...)
|
||||
* ticket: A ticket to obtain it's data
|
||||
"""
|
||||
try:
|
||||
fnc = PROCESSORS[data['type']]
|
||||
@ -134,7 +173,7 @@ def process(data: typing.Dict[str, typing.Any]) -> typing.Any:
|
||||
return rest_result('error', error=f'Invalid event type {data.get("type", "not_found")}')
|
||||
|
||||
try:
|
||||
return fnc(data)
|
||||
return fnc(server, data)
|
||||
except Exception as e:
|
||||
logger.error('Exception processing event %s: %s', data, e)
|
||||
return rest_result('error', error=str(e))
|
||||
|
@ -51,6 +51,7 @@ SECURED = '#SECURE#' # Just a "different" owner. If used anywhere, it's not imp
|
||||
# Note that the tunnel ticket will be the owner + the ticket itself, so it will be 48 chars long (Secured or not)
|
||||
TICKET_LENGTH = 40 # Ticket length must much the length of the ticket length on tunnel server!!! (take care with previous note)
|
||||
|
||||
|
||||
class TicketStore(UUIDModel):
|
||||
"""
|
||||
Tickets storing on DB
|
||||
@ -62,14 +63,10 @@ class TicketStore(UUIDModel):
|
||||
|
||||
owner = models.CharField(null=True, blank=True, default=None, max_length=8)
|
||||
stamp = models.DateTimeField() # Date creation or validation of this entry
|
||||
validity = models.IntegerField(
|
||||
default=60
|
||||
) # Duration allowed for this ticket to be valid, in seconds
|
||||
validity = models.IntegerField(default=60) # Duration allowed for this ticket to be valid, in seconds
|
||||
|
||||
data = models.BinaryField() # Associated ticket data
|
||||
validator = models.BinaryField(
|
||||
null=True, blank=True, default=None
|
||||
) # Associated validator for this ticket
|
||||
validator = models.BinaryField(null=True, blank=True, default=None) # Associated validator for this ticket
|
||||
|
||||
# "fake" declarations for type checking
|
||||
# objects: 'models.manager.Manager[TicketStore]'
|
||||
@ -87,9 +84,7 @@ class TicketStore(UUIDModel):
|
||||
|
||||
@staticmethod
|
||||
def generateUuid() -> str:
|
||||
return (
|
||||
CryptoManager().randomString(TICKET_LENGTH).lower()
|
||||
) # Temporary fix lower() for compat with 3.0
|
||||
return CryptoManager().randomString(TICKET_LENGTH).lower() # Temporary fix lower() for compat with 3.0
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
@ -99,8 +94,17 @@ class TicketStore(UUIDModel):
|
||||
owner: typing.Optional[str] = None,
|
||||
secure: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
validity is in seconds
|
||||
"""Creates a ticket (used to store data that can be retrieved later using REST API, for example)
|
||||
|
||||
Args:
|
||||
data: Data to store on ticket
|
||||
validatorFnc: Optional validator function. If present, it will be called with the data as parameter. If it returns False, the ticket will be invalid
|
||||
validity: Validity of the ticket, in seconds
|
||||
owner: Optional owner of the ticket. If present, only the owner can retrieve the ticket
|
||||
secure: If true, the data will be encrypted using the owner as key. If owner is not present, an exception will be raised
|
||||
|
||||
Returns:
|
||||
The ticket id
|
||||
"""
|
||||
validator = pickle.dumps(validatorFnc) if validatorFnc else None
|
||||
|
||||
@ -110,19 +114,16 @@ class TicketStore(UUIDModel):
|
||||
if not owner:
|
||||
raise ValueError('Tried to use a secure ticket without owner')
|
||||
data = CryptoManager().AESCrypt(data, owner.encode())
|
||||
owner = SECURED # So data is REALLY encrypted
|
||||
owner = SECURED # So data is REALLY encrypted, because key used to encrypt is sustituted by SECURED
|
||||
|
||||
return (
|
||||
TicketStore.objects.create(
|
||||
uuid=TicketStore.generateUuid(),
|
||||
stamp=getSqlDatetime(),
|
||||
data=data,
|
||||
validator=validator,
|
||||
validity=validity,
|
||||
owner=owner,
|
||||
).uuid
|
||||
or ''
|
||||
)
|
||||
return TicketStore.objects.create(
|
||||
uuid=TicketStore.generateUuid(),
|
||||
stamp=getSqlDatetime(),
|
||||
data=data,
|
||||
validator=validator,
|
||||
validity=validity,
|
||||
owner=owner,
|
||||
).uuid
|
||||
|
||||
@staticmethod
|
||||
def get(
|
||||
@ -149,13 +150,9 @@ class TicketStore(UUIDModel):
|
||||
data: bytes = t.data
|
||||
|
||||
if secure: # Owner has already been tested and it's not emtpy
|
||||
data = CryptoManager().AESDecrypt(
|
||||
data, typing.cast(str, owner).encode()
|
||||
)
|
||||
data = CryptoManager().AESDecrypt(data, typing.cast(str, owner).encode())
|
||||
|
||||
data = pickle.loads(
|
||||
data
|
||||
) # nosec: Tickets are generated by us, so we know they are safe
|
||||
data = pickle.loads(data) # nosec: Tickets are generated by us, so we know they are safe
|
||||
|
||||
# If has validator, execute it
|
||||
if t.validator:
|
||||
@ -190,13 +187,9 @@ class TicketStore(UUIDModel):
|
||||
if secure: # Owner has already been tested and it's not emtpy
|
||||
if not owner:
|
||||
raise ValueError('Tried to use a secure ticket without owner')
|
||||
data = CryptoManager().AESDecrypt(
|
||||
data, typing.cast(str, owner).encode()
|
||||
)
|
||||
data = CryptoManager().AESDecrypt(data, typing.cast(str, owner).encode())
|
||||
|
||||
dct = pickle.loads(
|
||||
data
|
||||
) # nosec: Tickets are ONLY generated by us, so we know they are safe
|
||||
dct = pickle.loads(data) # nosec: Tickets are ONLY generated by us, so we know they are safe
|
||||
|
||||
# invoke check function
|
||||
if checkFnc(dct) is False:
|
||||
@ -311,11 +304,7 @@ class TicketStore(UUIDModel):
|
||||
|
||||
def __str__(self) -> str:
|
||||
# Tickets are generated by us, so we know they are safe
|
||||
data = (
|
||||
pickle.loads(self.data) # nosec
|
||||
if self.owner != SECURED
|
||||
else '{Secure Ticket}'
|
||||
)
|
||||
data = pickle.loads(self.data) if self.owner != SECURED else '{Secure Ticket}' # nosec
|
||||
|
||||
return (
|
||||
f'Ticket id: {self.uuid}, Owner: {self.owner}, Stamp: {self.stamp}, '
|
||||
|
@ -54,13 +54,10 @@ class UUIDModel(models.Model):
|
||||
class Meta: # pylint: disable=too-few-public-methods
|
||||
abstract = True
|
||||
|
||||
def genUuid(self) -> str:
|
||||
return generateUuid()
|
||||
|
||||
# Override default save to add uuid
|
||||
def save(self, *args, **kwargs):
|
||||
if not self.uuid:
|
||||
self.uuid = self.genUuid()
|
||||
self.uuid = generateUuid()
|
||||
elif self.uuid != self.uuid.lower():
|
||||
self.uuid = (
|
||||
self.uuid.lower()
|
||||
|
Loading…
Reference in New Issue
Block a user