From 2c77d361d7d50c0e30a355dd05e348185952e85b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adolfo=20G=C3=B3mez=20Garc=C3=ADa?= Date: Sun, 21 May 2023 16:23:18 +0200 Subject: [PATCH 1/2] backported 4.0 version improvements --- tunnel-server/src/uds_tunnel/proxy.py | 13 ++- tunnel-server/src/uds_tunnel/tunnel.py | 11 ++- tunnel-server/src/udstunnel.py | 109 ++++++++++++++----------- 3 files changed, 79 insertions(+), 54 deletions(-) diff --git a/tunnel-server/src/uds_tunnel/proxy.py b/tunnel-server/src/uds_tunnel/proxy.py index 93abf28bd..bfc012e7e 100644 --- a/tunnel-server/src/uds_tunnel/proxy.py +++ b/tunnel-server/src/uds_tunnel/proxy.py @@ -74,11 +74,17 @@ class Proxy: # Handshake correct in this point, upgrade the connection to TSL and let # the protocol controller do the rest + # Store source ip and port, for logging purposes in case of error + src_ip, src_port = (source.getpeername() if source else ('Unknown', 0))[:2] # May be ipv4 or ipv6, so we get only first two elements + # Upgrade connection to SSL, and use asyncio to handle the rest + tun: typing.Optional[tunnel.TunnelProtocol] = None try: + tun = tunnel.TunnelProtocol(self) # (connect accepted loop not present on AbastractEventLoop definition < 3.10), that's why we use ignore await loop.connect_accepted_socket( # type: ignore - lambda: tunnel.TunnelProtocol(self), source, ssl=context + lambda: tun, source, ssl=context, + ssl_handshake_timeout=3, ) # Wait for connection to be closed @@ -86,6 +92,11 @@ class Proxy: except asyncio.CancelledError: pass # Return on cancel + except Exception as e: + # Any other exception, ensure we close the connection + logger.error('ERROR on %s:%s: %s', src_ip, src_port, e) + if tun: + tun.close_connection() logger.debug('Proxy finished') diff --git a/tunnel-server/src/uds_tunnel/tunnel.py b/tunnel-server/src/uds_tunnel/tunnel.py index 209bd0098..93baa0fb2 100644 --- a/tunnel-server/src/uds_tunnel/tunnel.py +++ b/tunnel-server/src/uds_tunnel/tunnel.py @@ -97,8 +97,6 @@ class TunnelProtocol(asyncio.Protocol): # We start processing command # After command, we can process stats or do_proxy, that is the "normal" operation self.runner = self.do_command - # Set starting timeout task, se we dont get hunged on connections without data (or insufficient data) - self.set_timeout(self.owner.cfg.command_timeout) def process_open(self) -> None: # Open Command has the ticket behind it @@ -158,7 +156,7 @@ class TunnelProtocol(asyncio.Protocol): self.transport.write(b'OK') self.stats_manager.increment_connections() # Increment connections counters except Exception as e: - logger.error('Error opening connection: %s', e) + logger.error('CONNECTION FAILED: %s', e) self.close_connection() # add open other side to the loop @@ -276,9 +274,10 @@ class TunnelProtocol(asyncio.Protocol): def close_connection(self): try: + self.clean_timeout() # If a timeout is set, clean it if not self.transport.is_closing(): self.transport.close() - except Exception: + except Exception: # nosec: best effort pass # Ignore errors def notify_end(self): @@ -304,6 +303,10 @@ class TunnelProtocol(asyncio.Protocol): def connection_made(self, transport: 'asyncio.transports.BaseTransport') -> None: # We know for sure that the transport is a Transport. + + # Set starting timeout task, se we dont get hunged on connections without data (or insufficient data) + self.set_timeout(self.owner.cfg.command_timeout) + self.transport = typing.cast('asyncio.transports.Transport', transport) # Get source self.source = self.transport.get_extra_info('peername') diff --git a/tunnel-server/src/udstunnel.py b/tunnel-server/src/udstunnel.py index a5112c540..622492028 100755 --- a/tunnel-server/src/udstunnel.py +++ b/tunnel-server/src/udstunnel.py @@ -37,20 +37,22 @@ import argparse import signal import ssl import socket -import logging -from concurrent.futures import ThreadPoolExecutor -# event for stop notification -import threading +import threading # event for stop notification import typing +import logging +from logging.handlers import RotatingFileHandler +from concurrent.futures import ThreadPoolExecutor + try: - import uvloop + import uvloop # type: ignore + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) except ImportError: pass # no uvloop support try: - import setproctitle + import setproctitle # type: ignore except ImportError: setproctitle = None # type: ignore @@ -73,8 +75,6 @@ def stop_signal(signum: int, frame: typing.Any) -> None: def setup_log(cfg: config.ConfigurationType) -> None: - from logging.handlers import RotatingFileHandler - # Update logging if needed if cfg.logfile: fileh = RotatingFileHandler( @@ -96,9 +96,7 @@ def setup_log(cfg: config.ConfigurationType) -> None: log.setLevel(cfg.loglevel) handler = logging.StreamHandler(sys.stderr) handler.setLevel(cfg.loglevel) - formatter = logging.Formatter( - '%(levelname)s - %(message)s' - ) # Basic log format, nice for syslog + formatter = logging.Formatter('%(levelname)s - %(message)s') # Basic log format, nice for syslog handler.setFormatter(formatter) log.addHandler(handler) @@ -107,25 +105,25 @@ def setup_log(cfg: config.ConfigurationType) -> None: logger.debug('Configuration: %s', cfg) -async def tunnel_proc_async( - pipe: 'Connection', cfg: config.ConfigurationType, ns: 'Namespace' -) -> None: - +async def tunnel_proc_async(pipe: 'Connection', cfg: config.ConfigurationType, ns: 'Namespace') -> None: loop = asyncio.get_running_loop() tasks: typing.List[asyncio.Task] = [] def add_autoremovable_task(task: asyncio.Task) -> None: tasks.append(task) - task.add_done_callback(tasks.remove) + + def remove_task(task: asyncio.Task) -> None: + logger.debug('Removing task %s', task) + tasks.remove(task) + + task.add_done_callback(remove_task) def get_socket() -> typing.Tuple[typing.Optional[socket.socket], typing.Optional[typing.Tuple[str, int]]]: try: while True: # Clear back event, for next data - msg: typing.Optional[ - typing.Tuple[socket.socket, typing.Tuple[str, int]] - ] = pipe.recv() + msg: typing.Optional[typing.Tuple[socket.socket, typing.Tuple[str, int]]] = pipe.recv() if msg: return msg except EOFError: @@ -147,19 +145,31 @@ async def tunnel_proc_async( args['keyfile'] = cfg.ssl_certificate_key if cfg.ssl_password: args['password'] = cfg.ssl_password - + context.load_cert_chain(**args) # Set min version from string (1.2 or 1.3) as ssl.TLSVersion.TLSv1_2 or ssl.TLSVersion.TLSv1_3 if cfg.ssl_min_tls_version in ('1.2', '1.3'): - context.minimum_version = getattr(ssl.TLSVersion, f'TLSv1_{cfg.ssl_min_tls_version.split(".")[1]}') + try: + context.minimum_version = getattr( + ssl.TLSVersion, f'TLSv1_{cfg.ssl_min_tls_version.split(".")[1]}' + ) + except Exception as e: + logger.exception('Setting min tls version failed: %s. Using defaults', e) + context.minimum_version = ssl.TLSVersion.TLSv1_2 # Any other value will be ignored if cfg.ssl_ciphers: - context.set_ciphers(cfg.ssl_ciphers) + try: + context.set_ciphers(cfg.ssl_ciphers) + except Exception as e: + logger.exception('Setting ciphers failed: %s. Using defaults', e) if cfg.ssl_dhparam: - context.load_dh_params(cfg.ssl_dhparam) + try: + context.load_dh_params(cfg.ssl_dhparam) + except Exception as e: + logger.exception('Loading dhparams failed: %s. Using defaults', e) try: while True: @@ -168,28 +178,34 @@ async def tunnel_proc_async( (sock, address) = await loop.run_in_executor(None, get_socket) if not sock: break # No more sockets, exit - logger.debug(f'CONNECTION from {address!r} (pid: {os.getpid()})') + logger.debug('CONNECTION from %s (pid: %s)', address, os.getpid()) # Due to proxy contains an "event" to stop, we need to create a new one for each connection - add_autoremovable_task(asyncio.create_task(proxy.Proxy(cfg, ns)(sock, context))) - except asyncio.CancelledError: - raise + add_autoremovable_task( + asyncio.create_task(proxy.Proxy(cfg, ns)(sock, context), name=f'proxy-{address}') + ) + except asyncio.CancelledError: # pylint: disable=try-except-raise + raise # Stop, but avoid generic exception except Exception: logger.error('NEGOTIATION ERROR from %s', address[0] if address else 'unknown') except asyncio.CancelledError: pass # Stop # create task for server - - add_autoremovable_task(asyncio.create_task(run_server())) + + add_autoremovable_task(asyncio.create_task(run_server(), name='server')) try: while tasks and not do_stop.is_set(): to_wait = tasks[:] # Get a copy of the list # Wait for "to_wait" tasks to finish, stop every 2 seconds to check if we need to stop - done, _ = await asyncio.wait(to_wait, return_when=asyncio.FIRST_COMPLETED, timeout=2) + # done, _ = + await asyncio.wait(to_wait, return_when=asyncio.FIRST_COMPLETED, timeout=2) except asyncio.CancelledError: logger.info('Task cancelled') do_stop.set() # ensure we stop + except Exception: + logger.exception('Error in main loop') + do_stop.set() logger.debug('Out of loop, stopping tasks: %s, running: %s', tasks, do_stop.is_set()) @@ -199,7 +215,7 @@ async def tunnel_proc_async( task.cancel() except asyncio.CancelledError: pass # Ignore, we are stopping - + # for task in tasks: # task.cancel() @@ -208,16 +224,15 @@ async def tunnel_proc_async( logger.info('PROCESS %s stopped', os.getpid()) -def process_connection( - client: socket.socket, addr: typing.Tuple[str, str], conn: 'Connection' -) -> None: + +def process_connection(client: socket.socket, addr: typing.Tuple[str, str], conn: 'Connection') -> None: data: bytes = b'' try: # First, ensure handshake (simple handshake) and command data = client.recv(len(consts.HANDSHAKE_V1)) if data != consts.HANDSHAKE_V1: - raise Exception('Invalid data from {}: {}'.format(addr[0], data.hex())) # Invalid handshake + raise Exception(f'Invalid data from {addr[0]}: {data.hex()}') # Invalid handshake conn.send((client, addr)) del client # Ensure socket is controlled on child process except Exception as e: @@ -231,11 +246,11 @@ def tunnel_main(args: 'argparse.Namespace') -> None: # Try to bind to port as running user # Wait for socket incoming connections and spread them - socket.setdefaulttimeout( - 3.0 - ) # So we can check for stop from time to time and not block forever - af_inet = socket.AF_INET6 if args.ipv6 or cfg.ipv6 or ':' in cfg.listen_address else socket.AF_INET - sock = socket.socket(af_inet, socket.SOCK_STREAM) + socket.setdefaulttimeout(3.0) # So we can check for stop from time to time and not block forever + sock = socket.socket( + socket.AF_INET6 if args.ipv6 or cfg.ipv6 or ':' in cfg.listen_address else socket.AF_INET, + socket.SOCK_STREAM, + ) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True) sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) # We will not reuse port, we only want a UDS tunnel server running on a port @@ -257,15 +272,13 @@ def tunnel_main(args: 'argparse.Namespace') -> None: setup_log(cfg) - logger.info( - 'Starting tunnel server on %s:%s', cfg.listen_address, cfg.listen_port - ) + logger.info('Starting tunnel server on %s:%s', cfg.listen_address, cfg.listen_port) if setproctitle: setproctitle.setproctitle(f'UDSTunnel {cfg.listen_address}:{cfg.listen_port}') # Create pid file if cfg.pidfile: - with open(cfg.pidfile, mode='w') as f: + with open(cfg.pidfile, mode='w', encoding='utf-8') as f: f.write(str(os.getpid())) except Exception as e: @@ -278,7 +291,7 @@ def tunnel_main(args: 'argparse.Namespace') -> None: signal.signal(signal.SIGINT, stop_signal) signal.signal(signal.SIGTERM, stop_signal) except Exception as e: - # Signal not available on threads, and we use threads on tests, + # Signal not available on threads, and we use threads on tests, # so we will ignore this because on tests signals are not important logger.warning('Signal not available: %s', e) @@ -317,7 +330,7 @@ def tunnel_main(args: 'argparse.Namespace') -> None: if cfg.pidfile: os.unlink(cfg.pidfile) except Exception: - pass + logger.warning('Could not remove pidfile %s', cfg.pidfile) logger.info('FINISHED') @@ -325,9 +338,7 @@ def tunnel_main(args: 'argparse.Namespace') -> None: def main() -> None: parser = argparse.ArgumentParser() group = parser.add_mutually_exclusive_group() - group.add_argument( - '-t', '--tunnel', help='Starts the tunnel server', action='store_true' - ) + group.add_argument('-t', '--tunnel', help='Starts the tunnel server', action='store_true') # group.add_argument('-r', '--rdp', help='RDP Tunnel for traffic accounting') group.add_argument( '-s', From 084e0cc2a08bcfba6293692a0ff5f6fe6e9e42f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adolfo=20G=C3=B3mez=20Garc=C3=ADa?= Date: Sun, 21 May 2023 16:49:04 +0200 Subject: [PATCH 2/2] fixed bad ssl handshake management --- tunnel-server/src/uds_tunnel/proxy.py | 3 +++ tunnel-server/src/uds_tunnel/tunnel.py | 8 +++++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/tunnel-server/src/uds_tunnel/proxy.py b/tunnel-server/src/uds_tunnel/proxy.py index bfc012e7e..6648aa4ac 100644 --- a/tunnel-server/src/uds_tunnel/proxy.py +++ b/tunnel-server/src/uds_tunnel/proxy.py @@ -97,6 +97,9 @@ class Proxy: logger.error('ERROR on %s:%s: %s', src_ip, src_port, e) if tun: tun.close_connection() + # Also, ensure socket is closed + if source: + del source logger.debug('Proxy finished') diff --git a/tunnel-server/src/uds_tunnel/tunnel.py b/tunnel-server/src/uds_tunnel/tunnel.py index 93baa0fb2..c4df5703e 100644 --- a/tunnel-server/src/uds_tunnel/tunnel.py +++ b/tunnel-server/src/uds_tunnel/tunnel.py @@ -275,10 +275,12 @@ class TunnelProtocol(asyncio.Protocol): def close_connection(self): try: self.clean_timeout() # If a timeout is set, clean it - if not self.transport.is_closing(): + if not self.transport.is_closing(): # Attribute may alreade not be set self.transport.close() - except Exception: # nosec: best effort - pass # Ignore errors + except AttributeError: # not initialized transport, fine... + pass + except Exception as e: # nosec: best effort + logger.error('ERROR closing connection: %s', e) def notify_end(self): if self.notify_ticket: