1
0
mirror of https://github.com/dkmstr/openuds.git synced 2025-01-10 01:17:59 +03:00

upgrading and linting tunnel

This commit is contained in:
Adolfo Gómez García 2023-05-16 01:09:32 +02:00
parent ddf07eb68b
commit 4d26df9580
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
2 changed files with 22 additions and 39 deletions

View File

@ -28,6 +28,7 @@
''' '''
Author: Adolfo Gómez, dkmaster at dkmon dot com Author: Adolfo Gómez, dkmaster at dkmon dot com
''' '''
import re
import typing import typing
DEBUG = True DEBUG = True
@ -71,3 +72,6 @@ RESPONSE_OK: typing.Final[bytes] = b'OK'
# Backlog for listen socket # Backlog for listen socket
BACKLOG = 1024 BACKLOG = 1024
# Regular expression for parsing ticket
TICKET_REGEX = re.compile(f'^[a-zA-Z0-9]{{{TICKET_LENGTH}}}$')

View File

@ -78,9 +78,7 @@ class TunnelProtocol(asyncio.Protocol):
# If there is a timeout task running # If there is a timeout task running
timeout_task: typing.Optional[asyncio.Task] = None timeout_task: typing.Optional[asyncio.Task] = None
def __init__( def __init__(self, owner: 'proxy.Proxy') -> None:
self, owner: 'proxy.Proxy'
) -> None:
# If no other side is given, we are the server part # If no other side is given, we are the server part
super().__init__() super().__init__()
# transport is undefined until connection_made is called # transport is undefined until connection_made is called
@ -91,7 +89,7 @@ class TunnelProtocol(asyncio.Protocol):
self.destination = ('', 0) self.destination = ('', 0)
self.tls_version = '' self.tls_version = ''
self.tls_cipher = '' self.tls_cipher = ''
# If other_side is given, we are the client part (that is, the tunnel from us to remote machine) # If other_side is given, we are the client part (that is, the tunnel from us to remote machine)
# In this case, only do_proxy is used # In this case, only do_proxy is used
self.client = None self.client = None
@ -124,9 +122,7 @@ class TunnelProtocol(asyncio.Protocol):
async def open_client() -> None: async def open_client() -> None:
try: try:
result = await TunnelProtocol.get_ticket_from_uds( result = await TunnelProtocol.get_ticket_from_uds(self.owner.cfg, ticket, self.source)
self.owner.cfg, ticket, self.source
)
except Exception as e: except Exception as e:
logger.error('ERROR %s', e.args[0] if e.args else e) logger.error('ERROR %s', e.args[0] if e.args else e)
self.transport.write(consts.RESPONSE_ERROR_TICKET) self.transport.write(consts.RESPONSE_ERROR_TICKET)
@ -146,8 +142,7 @@ class TunnelProtocol(asyncio.Protocol):
try: try:
family = ( family = (
socket.AF_INET6 socket.AF_INET6
if ':' in self.destination[0] if ':' in self.destination[0] or (self.owner.cfg.ipv6 and '.' not in self.destination[0])
or (self.owner.cfg.ipv6 and not '.' in self.destination[0])
else socket.AF_INET else socket.AF_INET
) )
(_, self.client) = await loop.create_connection( (_, self.client) = await loop.create_connection(
@ -161,7 +156,7 @@ class TunnelProtocol(asyncio.Protocol):
self.transport.resume_reading() self.transport.resume_reading()
# send OK to client # send OK to client
self.transport.write(b'OK') self.transport.write(b'OK')
self.stats_manager.increment_connections() # Increment connections counters self.stats_manager.increment_connections() # Increment connections counters
except Exception as e: except Exception as e:
logger.error('Error opening connection: %s', e) logger.error('Error opening connection: %s', e)
self.close_connection() self.close_connection()
@ -171,7 +166,7 @@ class TunnelProtocol(asyncio.Protocol):
# From now, proxy connection # From now, proxy connection
self.runner = self.do_proxy self.runner = self.do_proxy
def process_stats(self, full: bool) -> None: def process_stats(self, full: bool) -> None: # pylint: disable=unused-argument
# if pasword is not already received, wait for it # if pasword is not already received, wait for it
if len(self.cmd) < consts.PASSWORD_LENGTH + consts.COMMAND_LENGTH: if len(self.cmd) < consts.PASSWORD_LENGTH + consts.COMMAND_LENGTH:
return return
@ -246,22 +241,22 @@ class TunnelProtocol(asyncio.Protocol):
try: try:
if command == consts.COMMAND_OPEN: if command == consts.COMMAND_OPEN:
self.process_open() self.process_open()
elif command == consts.COMMAND_TEST: return
if command == consts.COMMAND_TEST:
self.clean_timeout() # Stop timeout self.clean_timeout() # Stop timeout
logger.info('COMMAND: TEST') logger.info('COMMAND: TEST')
self.transport.write(consts.RESPONSE_OK) self.transport.write(consts.RESPONSE_OK)
self.close_connection() self.close_connection()
return return
elif command in (consts.COMMAND_STAT, consts.COMMAND_INFO): if command in (consts.COMMAND_STAT, consts.COMMAND_INFO):
# This is an stats requests # This is an stats requests
try: try:
self.process_stats(full=command == consts.COMMAND_STAT) self.process_stats(full=command == consts.COMMAND_STAT)
except Exception as e: except Exception as e:
logger.error('ERROR processing stats: %s', e.args[0] if e.args else e) logger.error('ERROR processing stats: %s', e.args[0] if e.args else e)
self.close_connection() self.close_connection()
return return
else: raise Exception('Invalid command')
raise Exception('Invalid command')
except Exception: except Exception:
logger.error('ERROR from %s', self.pretty_source()) logger.error('ERROR from %s', self.pretty_source())
self.transport.write(consts.RESPONSE_ERROR_COMMAND) self.transport.write(consts.RESPONSE_ERROR_COMMAND)
@ -298,9 +293,7 @@ class TunnelProtocol(asyncio.Protocol):
) )
# Notify end to uds, using a task becase we are not an async function # Notify end to uds, using a task becase we are not an async function
asyncio.get_event_loop().create_task( asyncio.get_event_loop().create_task(
TunnelProtocol.notify_end_to_uds( TunnelProtocol.notify_end_to_uds(self.owner.cfg, self.notify_ticket, self.stats_manager)
self.owner.cfg, self.notify_ticket, self.stats_manager
)
) )
self.notify_ticket = b'' # Clean up so no more notifications self.notify_ticket = b'' # Clean up so no more notifications
else: else:
@ -350,7 +343,6 @@ class TunnelProtocol(asyncio.Protocol):
def pretty_destination(self) -> str: def pretty_destination(self) -> str:
return TunnelProtocol.pretty_address(self.destination) return TunnelProtocol.pretty_address(self.destination)
@staticmethod @staticmethod
async def _read_from_uds( async def _read_from_uds(
cfg: config.ConfigurationType, cfg: config.ConfigurationType,
@ -359,13 +351,9 @@ class TunnelProtocol(asyncio.Protocol):
queryParams: typing.Optional[typing.Mapping[str, str]] = None, queryParams: typing.Optional[typing.Mapping[str, str]] = None,
) -> typing.MutableMapping[str, typing.Any]: ) -> typing.MutableMapping[str, typing.Any]:
try: try:
url = ( url = cfg.uds_server + '/' + ticket.decode() + '/' + msg + '/' + cfg.uds_token
cfg.uds_server + '/' + ticket.decode() + '/' + msg + '/' + cfg.uds_token
)
if queryParams: if queryParams:
url += '?' + '&'.join( url += '?' + '&'.join([f'{key}={value}' for key, value in queryParams.items()])
[f'{key}={value}' for key, value in queryParams.items()]
)
# Set options # Set options
options: typing.Dict[str, typing.Any] = {'timeout': cfg.uds_timeout} options: typing.Dict[str, typing.Any] = {'timeout': cfg.uds_timeout}
if cfg.uds_verify_ssl is False: if cfg.uds_verify_ssl is False:
@ -378,24 +366,15 @@ class TunnelProtocol(asyncio.Protocol):
raise Exception(await r.text()) raise Exception(await r.text())
return await r.json() return await r.json()
except Exception as e: except Exception as e:
raise Exception(f'TICKET COMMS ERROR: {ticket.decode()} {msg} {e!s}') raise Exception(f'TICKET COMMS ERROR: {ticket.decode()} {msg} {e!s}') from e
@staticmethod @staticmethod
async def get_ticket_from_uds( async def get_ticket_from_uds(
cfg: config.ConfigurationType, ticket: bytes, address: typing.Tuple[str, int] cfg: config.ConfigurationType, ticket: bytes, address: typing.Tuple[str, int]
) -> typing.MutableMapping[str, typing.Any]: ) -> typing.MutableMapping[str, typing.Any]:
# Sanity checks # Check ticket using re
if len(ticket) != consts.TICKET_LENGTH: if consts.TICKET_REGEX.match(ticket.decode(errors='replace')) is None:
raise ValueError(f'TICKET INVALID (len={len(ticket)})') raise ValueError(f'TICKET INVALID ({ticket.decode(errors="replace")})')
for n, i in enumerate(ticket.decode(errors='ignore')):
if (
(i >= 'a' and i <= 'z')
or (i >= '0' and i <= '9')
or (i >= 'A' and i <= 'Z')
):
continue # Correctus
raise ValueError(f'TICKET INVALID (char {i} at pos {n})')
return await TunnelProtocol._read_from_uds(cfg, ticket, address[0]) return await TunnelProtocol._read_from_uds(cfg, ticket, address[0])