1
0
mirror of https://github.com/dkmstr/openuds.git synced 2024-12-22 13:34:04 +03:00

upgrading and linting tunnel

This commit is contained in:
Adolfo Gómez García 2023-05-16 01:09:32 +02:00
parent ddf07eb68b
commit 4d26df9580
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
2 changed files with 22 additions and 39 deletions

View File

@ -28,6 +28,7 @@
'''
Author: Adolfo Gómez, dkmaster at dkmon dot com
'''
import re
import typing
DEBUG = True
@ -71,3 +72,6 @@ RESPONSE_OK: typing.Final[bytes] = b'OK'
# Backlog for listen socket
BACKLOG = 1024
# Regular expression for parsing ticket
TICKET_REGEX = re.compile(f'^[a-zA-Z0-9]{{{TICKET_LENGTH}}}$')

View File

@ -78,9 +78,7 @@ class TunnelProtocol(asyncio.Protocol):
# If there is a timeout task running
timeout_task: typing.Optional[asyncio.Task] = None
def __init__(
self, owner: 'proxy.Proxy'
) -> None:
def __init__(self, owner: 'proxy.Proxy') -> None:
# If no other side is given, we are the server part
super().__init__()
# transport is undefined until connection_made is called
@ -91,7 +89,7 @@ class TunnelProtocol(asyncio.Protocol):
self.destination = ('', 0)
self.tls_version = ''
self.tls_cipher = ''
# If other_side is given, we are the client part (that is, the tunnel from us to remote machine)
# In this case, only do_proxy is used
self.client = None
@ -124,9 +122,7 @@ class TunnelProtocol(asyncio.Protocol):
async def open_client() -> None:
try:
result = await TunnelProtocol.get_ticket_from_uds(
self.owner.cfg, ticket, self.source
)
result = await TunnelProtocol.get_ticket_from_uds(self.owner.cfg, ticket, self.source)
except Exception as e:
logger.error('ERROR %s', e.args[0] if e.args else e)
self.transport.write(consts.RESPONSE_ERROR_TICKET)
@ -146,8 +142,7 @@ class TunnelProtocol(asyncio.Protocol):
try:
family = (
socket.AF_INET6
if ':' in self.destination[0]
or (self.owner.cfg.ipv6 and not '.' in self.destination[0])
if ':' in self.destination[0] or (self.owner.cfg.ipv6 and '.' not in self.destination[0])
else socket.AF_INET
)
(_, self.client) = await loop.create_connection(
@ -161,7 +156,7 @@ class TunnelProtocol(asyncio.Protocol):
self.transport.resume_reading()
# send OK to client
self.transport.write(b'OK')
self.stats_manager.increment_connections() # Increment connections counters
self.stats_manager.increment_connections() # Increment connections counters
except Exception as e:
logger.error('Error opening connection: %s', e)
self.close_connection()
@ -171,7 +166,7 @@ class TunnelProtocol(asyncio.Protocol):
# From now, proxy connection
self.runner = self.do_proxy
def process_stats(self, full: bool) -> None:
def process_stats(self, full: bool) -> None: # pylint: disable=unused-argument
# if pasword is not already received, wait for it
if len(self.cmd) < consts.PASSWORD_LENGTH + consts.COMMAND_LENGTH:
return
@ -246,22 +241,22 @@ class TunnelProtocol(asyncio.Protocol):
try:
if command == consts.COMMAND_OPEN:
self.process_open()
elif command == consts.COMMAND_TEST:
return
if command == consts.COMMAND_TEST:
self.clean_timeout() # Stop timeout
logger.info('COMMAND: TEST')
self.transport.write(consts.RESPONSE_OK)
self.close_connection()
return
elif command in (consts.COMMAND_STAT, consts.COMMAND_INFO):
if command in (consts.COMMAND_STAT, consts.COMMAND_INFO):
# This is an stats requests
try:
self.process_stats(full=command == consts.COMMAND_STAT)
self.process_stats(full=command == consts.COMMAND_STAT)
except Exception as e:
logger.error('ERROR processing stats: %s', e.args[0] if e.args else e)
self.close_connection()
return
else:
raise Exception('Invalid command')
raise Exception('Invalid command')
except Exception:
logger.error('ERROR from %s', self.pretty_source())
self.transport.write(consts.RESPONSE_ERROR_COMMAND)
@ -298,9 +293,7 @@ class TunnelProtocol(asyncio.Protocol):
)
# Notify end to uds, using a task becase we are not an async function
asyncio.get_event_loop().create_task(
TunnelProtocol.notify_end_to_uds(
self.owner.cfg, self.notify_ticket, self.stats_manager
)
TunnelProtocol.notify_end_to_uds(self.owner.cfg, self.notify_ticket, self.stats_manager)
)
self.notify_ticket = b'' # Clean up so no more notifications
else:
@ -350,7 +343,6 @@ class TunnelProtocol(asyncio.Protocol):
def pretty_destination(self) -> str:
return TunnelProtocol.pretty_address(self.destination)
@staticmethod
async def _read_from_uds(
cfg: config.ConfigurationType,
@ -359,13 +351,9 @@ class TunnelProtocol(asyncio.Protocol):
queryParams: typing.Optional[typing.Mapping[str, str]] = None,
) -> typing.MutableMapping[str, typing.Any]:
try:
url = (
cfg.uds_server + '/' + ticket.decode() + '/' + msg + '/' + cfg.uds_token
)
url = cfg.uds_server + '/' + ticket.decode() + '/' + msg + '/' + cfg.uds_token
if queryParams:
url += '?' + '&'.join(
[f'{key}={value}' for key, value in queryParams.items()]
)
url += '?' + '&'.join([f'{key}={value}' for key, value in queryParams.items()])
# Set options
options: typing.Dict[str, typing.Any] = {'timeout': cfg.uds_timeout}
if cfg.uds_verify_ssl is False:
@ -378,24 +366,15 @@ class TunnelProtocol(asyncio.Protocol):
raise Exception(await r.text())
return await r.json()
except Exception as e:
raise Exception(f'TICKET COMMS ERROR: {ticket.decode()} {msg} {e!s}')
raise Exception(f'TICKET COMMS ERROR: {ticket.decode()} {msg} {e!s}') from e
@staticmethod
async def get_ticket_from_uds(
cfg: config.ConfigurationType, ticket: bytes, address: typing.Tuple[str, int]
) -> typing.MutableMapping[str, typing.Any]:
# Sanity checks
if len(ticket) != consts.TICKET_LENGTH:
raise ValueError(f'TICKET INVALID (len={len(ticket)})')
for n, i in enumerate(ticket.decode(errors='ignore')):
if (
(i >= 'a' and i <= 'z')
or (i >= '0' and i <= '9')
or (i >= 'A' and i <= 'Z')
):
continue # Correctus
raise ValueError(f'TICKET INVALID (char {i} at pos {n})')
# Check ticket using re
if consts.TICKET_REGEX.match(ticket.decode(errors='replace')) is None:
raise ValueError(f'TICKET INVALID ({ticket.decode(errors="replace")})')
return await TunnelProtocol._read_from_uds(cfg, ticket, address[0])