From 4d26df95804ae58fbfab88e91e4918337a5d3208 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adolfo=20G=C3=B3mez=20Garc=C3=ADa?= Date: Tue, 16 May 2023 01:09:32 +0200 Subject: [PATCH] upgrading and linting tunnel --- tunnel-server/src/uds_tunnel/consts.py | 4 ++ tunnel-server/src/uds_tunnel/tunnel.py | 57 ++++++++------------------ 2 files changed, 22 insertions(+), 39 deletions(-) diff --git a/tunnel-server/src/uds_tunnel/consts.py b/tunnel-server/src/uds_tunnel/consts.py index e9210efa5..87fb0c0a6 100644 --- a/tunnel-server/src/uds_tunnel/consts.py +++ b/tunnel-server/src/uds_tunnel/consts.py @@ -28,6 +28,7 @@ ''' Author: Adolfo Gómez, dkmaster at dkmon dot com ''' +import re import typing DEBUG = True @@ -71,3 +72,6 @@ RESPONSE_OK: typing.Final[bytes] = b'OK' # Backlog for listen socket BACKLOG = 1024 + +# Regular expression for parsing ticket +TICKET_REGEX = re.compile(f'^[a-zA-Z0-9]{{{TICKET_LENGTH}}}$') diff --git a/tunnel-server/src/uds_tunnel/tunnel.py b/tunnel-server/src/uds_tunnel/tunnel.py index 22f1d275c..209bd0098 100644 --- a/tunnel-server/src/uds_tunnel/tunnel.py +++ b/tunnel-server/src/uds_tunnel/tunnel.py @@ -78,9 +78,7 @@ class TunnelProtocol(asyncio.Protocol): # If there is a timeout task running timeout_task: typing.Optional[asyncio.Task] = None - def __init__( - self, owner: 'proxy.Proxy' - ) -> None: + def __init__(self, owner: 'proxy.Proxy') -> None: # If no other side is given, we are the server part super().__init__() # transport is undefined until connection_made is called @@ -91,7 +89,7 @@ class TunnelProtocol(asyncio.Protocol): self.destination = ('', 0) self.tls_version = '' self.tls_cipher = '' - + # If other_side is given, we are the client part (that is, the tunnel from us to remote machine) # In this case, only do_proxy is used self.client = None @@ -124,9 +122,7 @@ class TunnelProtocol(asyncio.Protocol): async def open_client() -> None: try: - result = await TunnelProtocol.get_ticket_from_uds( - self.owner.cfg, ticket, self.source - ) + result = await TunnelProtocol.get_ticket_from_uds(self.owner.cfg, ticket, self.source) except Exception as e: logger.error('ERROR %s', e.args[0] if e.args else e) self.transport.write(consts.RESPONSE_ERROR_TICKET) @@ -146,8 +142,7 @@ class TunnelProtocol(asyncio.Protocol): try: family = ( socket.AF_INET6 - if ':' in self.destination[0] - or (self.owner.cfg.ipv6 and not '.' in self.destination[0]) + if ':' in self.destination[0] or (self.owner.cfg.ipv6 and '.' not in self.destination[0]) else socket.AF_INET ) (_, self.client) = await loop.create_connection( @@ -161,7 +156,7 @@ class TunnelProtocol(asyncio.Protocol): self.transport.resume_reading() # send OK to client self.transport.write(b'OK') - self.stats_manager.increment_connections() # Increment connections counters + self.stats_manager.increment_connections() # Increment connections counters except Exception as e: logger.error('Error opening connection: %s', e) self.close_connection() @@ -171,7 +166,7 @@ class TunnelProtocol(asyncio.Protocol): # From now, proxy connection self.runner = self.do_proxy - def process_stats(self, full: bool) -> None: + def process_stats(self, full: bool) -> None: # pylint: disable=unused-argument # if pasword is not already received, wait for it if len(self.cmd) < consts.PASSWORD_LENGTH + consts.COMMAND_LENGTH: return @@ -246,22 +241,22 @@ class TunnelProtocol(asyncio.Protocol): try: if command == consts.COMMAND_OPEN: self.process_open() - elif command == consts.COMMAND_TEST: + return + if command == consts.COMMAND_TEST: self.clean_timeout() # Stop timeout logger.info('COMMAND: TEST') self.transport.write(consts.RESPONSE_OK) self.close_connection() return - elif command in (consts.COMMAND_STAT, consts.COMMAND_INFO): + if command in (consts.COMMAND_STAT, consts.COMMAND_INFO): # This is an stats requests try: - self.process_stats(full=command == consts.COMMAND_STAT) + self.process_stats(full=command == consts.COMMAND_STAT) except Exception as e: logger.error('ERROR processing stats: %s', e.args[0] if e.args else e) self.close_connection() return - else: - raise Exception('Invalid command') + raise Exception('Invalid command') except Exception: logger.error('ERROR from %s', self.pretty_source()) self.transport.write(consts.RESPONSE_ERROR_COMMAND) @@ -298,9 +293,7 @@ class TunnelProtocol(asyncio.Protocol): ) # Notify end to uds, using a task becase we are not an async function asyncio.get_event_loop().create_task( - TunnelProtocol.notify_end_to_uds( - self.owner.cfg, self.notify_ticket, self.stats_manager - ) + TunnelProtocol.notify_end_to_uds(self.owner.cfg, self.notify_ticket, self.stats_manager) ) self.notify_ticket = b'' # Clean up so no more notifications else: @@ -350,7 +343,6 @@ class TunnelProtocol(asyncio.Protocol): def pretty_destination(self) -> str: return TunnelProtocol.pretty_address(self.destination) - @staticmethod async def _read_from_uds( cfg: config.ConfigurationType, @@ -359,13 +351,9 @@ class TunnelProtocol(asyncio.Protocol): queryParams: typing.Optional[typing.Mapping[str, str]] = None, ) -> typing.MutableMapping[str, typing.Any]: try: - url = ( - cfg.uds_server + '/' + ticket.decode() + '/' + msg + '/' + cfg.uds_token - ) + url = cfg.uds_server + '/' + ticket.decode() + '/' + msg + '/' + cfg.uds_token if queryParams: - url += '?' + '&'.join( - [f'{key}={value}' for key, value in queryParams.items()] - ) + url += '?' + '&'.join([f'{key}={value}' for key, value in queryParams.items()]) # Set options options: typing.Dict[str, typing.Any] = {'timeout': cfg.uds_timeout} if cfg.uds_verify_ssl is False: @@ -378,24 +366,15 @@ class TunnelProtocol(asyncio.Protocol): raise Exception(await r.text()) return await r.json() except Exception as e: - raise Exception(f'TICKET COMMS ERROR: {ticket.decode()} {msg} {e!s}') + raise Exception(f'TICKET COMMS ERROR: {ticket.decode()} {msg} {e!s}') from e @staticmethod async def get_ticket_from_uds( cfg: config.ConfigurationType, ticket: bytes, address: typing.Tuple[str, int] ) -> typing.MutableMapping[str, typing.Any]: - # Sanity checks - if len(ticket) != consts.TICKET_LENGTH: - raise ValueError(f'TICKET INVALID (len={len(ticket)})') - - for n, i in enumerate(ticket.decode(errors='ignore')): - if ( - (i >= 'a' and i <= 'z') - or (i >= '0' and i <= '9') - or (i >= 'A' and i <= 'Z') - ): - continue # Correctus - raise ValueError(f'TICKET INVALID (char {i} at pos {n})') + # Check ticket using re + if consts.TICKET_REGEX.match(ticket.decode(errors='replace')) is None: + raise ValueError(f'TICKET INVALID ({ticket.decode(errors="replace")})') return await TunnelProtocol._read_from_uds(cfg, ticket, address[0])