diff --git a/tunnel-server/src/uds_tunnel/config.py b/tunnel-server/src/uds_tunnel/config.py index 11673114d..da79f9c87 100644 --- a/tunnel-server/src/uds_tunnel/config.py +++ b/tunnel-server/src/uds_tunnel/config.py @@ -66,11 +66,20 @@ class ConfigurationType(typing.NamedTuple): uds_timeout: int uds_verify_ssl: bool + command_timeout: int + secret: str allow: typing.Set[str] use_uvloop: bool + def __str__(self) -> str: + return 'Configuration: \n' + '\n'.join( + f'{k}={v}' + for k, v in self._asdict().items() + ) + + def read_config_file( cfg_file: typing.Optional[typing.Union[typing.TextIO, str]] = None @@ -131,6 +140,7 @@ def read( uds_token=uds.get('uds_token', 'unauthorized'), uds_timeout=int(uds.get('uds_timeout', '10')), uds_verify_ssl=uds.get('uds_verify_ssl', 'true').lower() == 'true', + command_timeout=int(uds.get('command_timeout', '3')), secret=secret, allow=set(uds.get('allow', '127.0.0.1').split(',')), use_uvloop=uds.get('use_uvloop', 'true').lower() == 'true', diff --git a/tunnel-server/src/uds_tunnel/consts.py b/tunnel-server/src/uds_tunnel/consts.py index 93f7649d7..e9210efa5 100644 --- a/tunnel-server/src/uds_tunnel/consts.py +++ b/tunnel-server/src/uds_tunnel/consts.py @@ -69,8 +69,5 @@ RESPONSE_FORBIDDEN: typing.Final[bytes] = b'FORBIDDEN' RESPONSE_OK: typing.Final[bytes] = b'OK' -# Timeout for command -TIMEOUT_COMMAND: typing.Final[int] = 3 - # Backlog for listen socket BACKLOG = 1024 diff --git a/tunnel-server/src/uds_tunnel/tunnel.py b/tunnel-server/src/uds_tunnel/tunnel.py index 94d12a632..c8336901b 100644 --- a/tunnel-server/src/uds_tunnel/tunnel.py +++ b/tunnel-server/src/uds_tunnel/tunnel.py @@ -73,6 +73,13 @@ class TunnelProtocol(asyncio.Protocol): ) -> None: # If no other side is given, we are the server part super().__init__() + # transport is undefined until connection_made is called + self.cmd = b'' + self.notify_ticket = b'' + self.owner = owner + self.source = ('', 0) + self.destination = ('', 0) + if other_side: self.other_side = other_side self.stats_manager = other_side.stats_manager @@ -84,21 +91,15 @@ class TunnelProtocol(asyncio.Protocol): self.counter = self.stats_manager.as_sent_counter() self.runner = self.do_command # Set starting timeout task, se we dont get hunged on connections without data - self.set_timeout(consts.TIMEOUT_COMMAND) + self.set_timeout(self.owner.cfg.command_timeout) - # transport is undefined until connection_made is called - self.cmd = b'' - self.notify_ticket = b'' - self.owner = owner - self.source = ('', 0) - self.destination = ('', 0) def process_open(self) -> None: # Open Command has the ticket behind it if len(self.cmd) < consts.TICKET_LENGTH + consts.COMMAND_LENGTH: # Reactivate timeout, will be deactivated on do_command - self.set_timeout(consts.TIMEOUT_COMMAND) + self.set_timeout(self.owner.cfg.command_timeout) return # Wait for more data to complete OPEN command # Ticket received, now process it with UDS @@ -253,7 +254,7 @@ class TunnelProtocol(asyncio.Protocol): self.close_connection() return else: - self.set_timeout(consts.TIMEOUT_COMMAND) + self.set_timeout(self.owner.cfg.command_timeout) # if not enough data to process command, wait for more @@ -289,7 +290,6 @@ class TunnelProtocol(asyncio.Protocol): self.owner.finished.set() def connection_lost(self, exc: typing.Optional[Exception]) -> None: - logger.debug('Connection closed : %s', exc) # Ensure close other side if any if self.other_side is not self: self.other_side.transport.close() diff --git a/tunnel-server/src/udstunnel.conf b/tunnel-server/src/udstunnel.conf index 87f1351bf..3a307d396 100644 --- a/tunnel-server/src/udstunnel.conf +++ b/tunnel-server/src/udstunnel.conf @@ -53,6 +53,10 @@ uds_token = eBCeFxTBw1IKXCqq-RlncshwWIfrrqxc8y5nehqiqMtRztwD # If verify ssl certificate on uds server. Defaults to true # uds_verify_ssl = true +# Command timeout. Command reception on tunnel will timeout after this time (in seconds) +# defaults to 3 seconds +# command_timeout = 3 + # Secret to get access to admin commands (Currently only stats commands). No default for this. # Admin commands and only allowed from "allow" ips # So, in order to allow this commands, ensure listen address allows connections from localhost diff --git a/tunnel-server/src/udstunnel.py b/tunnel-server/src/udstunnel.py index 8240f8e71..1704fa8fd 100755 --- a/tunnel-server/src/udstunnel.py +++ b/tunnel-server/src/udstunnel.py @@ -102,6 +102,10 @@ def setup_log(cfg: config.ConfigurationType) -> None: handler.setFormatter(formatter) log.addHandler(handler) + # If debug, print config + if cfg.loglevel.lower() == 'debug': + logger.debug('Configuration: %s', cfg) + async def tunnel_proc_async( pipe: 'Connection', cfg: config.ConfigurationType, ns: 'Namespace' @@ -111,6 +115,10 @@ async def tunnel_proc_async( tasks: typing.List[asyncio.Task] = [] + def add_autoremovable_task(task: asyncio.Task) -> None: + tasks.append(task) + task.add_done_callback(tasks.remove) + def get_socket() -> typing.Tuple[typing.Optional[socket.socket], typing.Optional[typing.Tuple[str, int]]]: try: while True: @@ -157,7 +165,7 @@ async def tunnel_proc_async( break # No more sockets, exit logger.debug(f'CONNECTION from {address!r} (pid: {os.getpid()})') # Due to proxy contains an "event" to stop, we need to create a new one for each connection - tasks.append(asyncio.create_task(proxy.Proxy(cfg, ns)(sock, context))) + add_autoremovable_task(asyncio.create_task(proxy.Proxy(cfg, ns)(sock, context))) except asyncio.CancelledError: raise except Exception: @@ -166,23 +174,20 @@ async def tunnel_proc_async( pass # Stop # create task for server - tasks.append(asyncio.create_task(run_server())) + + add_autoremovable_task(asyncio.create_task(run_server())) try: while tasks and not do_stop.is_set(): to_wait = tasks[:] # Get a copy of the list # Wait for "to_wait" tasks to finish, stop every 2 seconds to check if we need to stop done, _ = await asyncio.wait(to_wait, return_when=asyncio.FIRST_COMPLETED, timeout=2) - # Remove finished tasks - for task in done: - tasks.remove(task) - if task.exception(): - logger.exception('TUNNEL ERROR') except asyncio.CancelledError: logger.info('Task cancelled') do_stop.set() # ensure we stop logger.debug('Out of loop, stopping tasks: %s, running: %s', tasks, do_stop.is_set()) + # If any task is still running, cancel it for task in tasks: task.cancel()