diff --git a/tunnel-server/samples/async_upgrade_server.py b/tunnel-server/samples/async_upgrade_server.py index 7eb568fef..668f0aa4d 100755 --- a/tunnel-server/samples/async_upgrade_server.py +++ b/tunnel-server/samples/async_upgrade_server.py @@ -92,7 +92,6 @@ class TunnelProtocol(asyncio.Protocol): logger.error('Invalid state reached!') def connection_lost(self, exc: typing.Optional[Exception]) -> None: - logger.debug('Connection closed : %s', exc) self.finished.set_result(True) if self.other_side is not self: self.other_side.transport.close() diff --git a/tunnel-server/src/uds_tunnel/config.py b/tunnel-server/src/uds_tunnel/config.py index 11673114d..070714099 100644 --- a/tunnel-server/src/uds_tunnel/config.py +++ b/tunnel-server/src/uds_tunnel/config.py @@ -66,11 +66,20 @@ class ConfigurationType(typing.NamedTuple): uds_timeout: int uds_verify_ssl: bool + command_timeout: float + secret: str allow: typing.Set[str] use_uvloop: bool + def __str__(self) -> str: + return 'Configuration: \n' + '\n'.join( + f'{k}={v}' + for k, v in self._asdict().items() + ) + + def read_config_file( cfg_file: typing.Optional[typing.Union[typing.TextIO, str]] = None @@ -131,6 +140,7 @@ def read( uds_token=uds.get('uds_token', 'unauthorized'), uds_timeout=int(uds.get('uds_timeout', '10')), uds_verify_ssl=uds.get('uds_verify_ssl', 'true').lower() == 'true', + command_timeout=float(uds.get('command_timeout', '3')), secret=secret, allow=set(uds.get('allow', '127.0.0.1').split(',')), use_uvloop=uds.get('use_uvloop', 'true').lower() == 'true', diff --git a/tunnel-server/src/uds_tunnel/consts.py b/tunnel-server/src/uds_tunnel/consts.py index 93f7649d7..e9210efa5 100644 --- a/tunnel-server/src/uds_tunnel/consts.py +++ b/tunnel-server/src/uds_tunnel/consts.py @@ -69,8 +69,5 @@ RESPONSE_FORBIDDEN: typing.Final[bytes] = b'FORBIDDEN' RESPONSE_OK: typing.Final[bytes] = b'OK' -# Timeout for command -TIMEOUT_COMMAND: typing.Final[int] = 3 - # Backlog for listen socket BACKLOG = 1024 diff --git a/tunnel-server/src/uds_tunnel/tunnel.py b/tunnel-server/src/uds_tunnel/tunnel.py index 94d12a632..cbf5b1b10 100644 --- a/tunnel-server/src/uds_tunnel/tunnel.py +++ b/tunnel-server/src/uds_tunnel/tunnel.py @@ -73,6 +73,13 @@ class TunnelProtocol(asyncio.Protocol): ) -> None: # If no other side is given, we are the server part super().__init__() + # transport is undefined until connection_made is called + self.cmd = b'' + self.notify_ticket = b'' + self.owner = owner + self.source = ('', 0) + self.destination = ('', 0) + if other_side: self.other_side = other_side self.stats_manager = other_side.stats_manager @@ -84,21 +91,15 @@ class TunnelProtocol(asyncio.Protocol): self.counter = self.stats_manager.as_sent_counter() self.runner = self.do_command # Set starting timeout task, se we dont get hunged on connections without data - self.set_timeout(consts.TIMEOUT_COMMAND) + self.set_timeout(self.owner.cfg.command_timeout) - # transport is undefined until connection_made is called - self.cmd = b'' - self.notify_ticket = b'' - self.owner = owner - self.source = ('', 0) - self.destination = ('', 0) def process_open(self) -> None: # Open Command has the ticket behind it if len(self.cmd) < consts.TICKET_LENGTH + consts.COMMAND_LENGTH: # Reactivate timeout, will be deactivated on do_command - self.set_timeout(consts.TIMEOUT_COMMAND) + self.set_timeout(self.owner.cfg.command_timeout) return # Wait for more data to complete OPEN command # Ticket received, now process it with UDS @@ -196,7 +197,7 @@ class TunnelProtocol(asyncio.Protocol): finally: self.close_connection() - async def timeout(self, wait: int) -> None: + async def timeout(self, wait: float) -> None: """Timeout can only occur while waiting for a command (or OPEN command ticket).""" try: await asyncio.sleep(wait) @@ -206,7 +207,7 @@ class TunnelProtocol(asyncio.Protocol): except asyncio.CancelledError: pass - def set_timeout(self, wait: int) -> None: + def set_timeout(self, wait: float) -> None: """Set a timeout for this connection. If reached, the connection will be closed. @@ -253,7 +254,7 @@ class TunnelProtocol(asyncio.Protocol): self.close_connection() return else: - self.set_timeout(consts.TIMEOUT_COMMAND) + self.set_timeout(self.owner.cfg.command_timeout) # if not enough data to process command, wait for more @@ -289,7 +290,6 @@ class TunnelProtocol(asyncio.Protocol): self.owner.finished.set() def connection_lost(self, exc: typing.Optional[Exception]) -> None: - logger.debug('Connection closed : %s', exc) # Ensure close other side if any if self.other_side is not self: self.other_side.transport.close() diff --git a/tunnel-server/src/udstunnel.conf b/tunnel-server/src/udstunnel.conf index 87f1351bf..3a307d396 100644 --- a/tunnel-server/src/udstunnel.conf +++ b/tunnel-server/src/udstunnel.conf @@ -53,6 +53,10 @@ uds_token = eBCeFxTBw1IKXCqq-RlncshwWIfrrqxc8y5nehqiqMtRztwD # If verify ssl certificate on uds server. Defaults to true # uds_verify_ssl = true +# Command timeout. Command reception on tunnel will timeout after this time (in seconds) +# defaults to 3 seconds +# command_timeout = 3 + # Secret to get access to admin commands (Currently only stats commands). No default for this. # Admin commands and only allowed from "allow" ips # So, in order to allow this commands, ensure listen address allows connections from localhost diff --git a/tunnel-server/src/udstunnel.py b/tunnel-server/src/udstunnel.py index 8240f8e71..1704fa8fd 100755 --- a/tunnel-server/src/udstunnel.py +++ b/tunnel-server/src/udstunnel.py @@ -102,6 +102,10 @@ def setup_log(cfg: config.ConfigurationType) -> None: handler.setFormatter(formatter) log.addHandler(handler) + # If debug, print config + if cfg.loglevel.lower() == 'debug': + logger.debug('Configuration: %s', cfg) + async def tunnel_proc_async( pipe: 'Connection', cfg: config.ConfigurationType, ns: 'Namespace' @@ -111,6 +115,10 @@ async def tunnel_proc_async( tasks: typing.List[asyncio.Task] = [] + def add_autoremovable_task(task: asyncio.Task) -> None: + tasks.append(task) + task.add_done_callback(tasks.remove) + def get_socket() -> typing.Tuple[typing.Optional[socket.socket], typing.Optional[typing.Tuple[str, int]]]: try: while True: @@ -157,7 +165,7 @@ async def tunnel_proc_async( break # No more sockets, exit logger.debug(f'CONNECTION from {address!r} (pid: {os.getpid()})') # Due to proxy contains an "event" to stop, we need to create a new one for each connection - tasks.append(asyncio.create_task(proxy.Proxy(cfg, ns)(sock, context))) + add_autoremovable_task(asyncio.create_task(proxy.Proxy(cfg, ns)(sock, context))) except asyncio.CancelledError: raise except Exception: @@ -166,23 +174,20 @@ async def tunnel_proc_async( pass # Stop # create task for server - tasks.append(asyncio.create_task(run_server())) + + add_autoremovable_task(asyncio.create_task(run_server())) try: while tasks and not do_stop.is_set(): to_wait = tasks[:] # Get a copy of the list # Wait for "to_wait" tasks to finish, stop every 2 seconds to check if we need to stop done, _ = await asyncio.wait(to_wait, return_when=asyncio.FIRST_COMPLETED, timeout=2) - # Remove finished tasks - for task in done: - tasks.remove(task) - if task.exception(): - logger.exception('TUNNEL ERROR') except asyncio.CancelledError: logger.info('Task cancelled') do_stop.set() # ensure we stop logger.debug('Out of loop, stopping tasks: %s, running: %s', tasks, do_stop.is_set()) + # If any task is still running, cancel it for task in tasks: task.cancel() diff --git a/tunnel-server/test/test_app_concurrency.py b/tunnel-server/test/test_concurrency.py similarity index 74% rename from tunnel-server/test/test_app_concurrency.py rename to tunnel-server/test/test_concurrency.py index fa6505a35..d573343ac 100644 --- a/tunnel-server/test/test_app_concurrency.py +++ b/tunnel-server/test/test_concurrency.py @@ -45,14 +45,6 @@ logger = logging.getLogger(__name__) class TestUDSTunnelApp(IsolatedAsyncioTestCase): - async def test_run_app_help(self) -> None: - # Executes the app with --help - async with tuntools.tunnel_app_runner(args=['--help']) as process: - - stdout, stderr = await process.communicate() - self.assertEqual(process.returncode, 0, f'{stdout!r} {stderr!r}') - self.assertEqual(stderr, b'') - self.assertIn(b'usage: udstunnel', stdout) async def client_task(self, host: str, tunnel_port: int, remote_port: int) -> None: received: bytes = b'' @@ -118,8 +110,8 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase): await callback_invoked.wait() self.assertEqual(received, test_str) - async def test_run_app_serve(self) -> None: - concurrent_tasks = 256 + async def test_app_concurrency(self) -> None: + concurrent_tasks = 512 fake_broker_port = 20000 tunnel_server_port = fake_broker_port + 1 remote_port = fake_broker_port + 2 @@ -154,13 +146,15 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase): logfile='/tmp/tunnel_test.log', loglevel='DEBUG', workers=4, + command_timeout=16, # Increase command timeout because heavy load we will create ) as process: + # Create a "bunch" of clients tasks = [ asyncio.create_task( self.client_task(host, tunnel_server_port, remote_port + i) ) - for i in range(concurrent_tasks) + async for i in tools.waitable_range(concurrent_tasks) ] # Wait for all tasks to finish @@ -171,3 +165,52 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase): task.result() # Queue should have all requests (concurrent_tasks*2, one for open and one for close) self.assertEqual(req_queue.qsize(), concurrent_tasks * 2) + + async def test_tunnel_proc_concurrency(self) -> None: + concurrent_tasks = 512 + fake_broker_port = 20000 + tunnel_server_port = fake_broker_port + 1 + remote_port = fake_broker_port + 2 + # Extracts the port from an string that has bX0bwmbPORTbX0bwmb in it + + req_queue: asyncio.Queue[bytes] = asyncio.Queue() + + def extract_port(data: bytes) -> int: + req_queue.put_nowait(data) + if b'bX0bwmb' not in data: + return 12345 # No port, wil not be used because is an "stop" request + return int(data.split(b'bX0bwmb')[1]) + + for host in ('127.0.0.1', '::1'): + if ':' in host: + url = f'http://[{host}]:{fake_broker_port}/uds/rest' + else: + url = f'http://{host}:{fake_broker_port}/uds/rest' + + req_queue = asyncio.Queue() # clear queue + # Use tunnel proc for testing + async with tuntools.create_tunnel_proc( + host, + tunnel_server_port, + response=lambda data: conf.UDS_GET_TICKET_RESPONSE( + host, extract_port(data) + ), + command_timeout=16, # Increase command timeout because heavy load we will create + ) as (cfg, _): + # Create a "bunch" of clients + tasks = [ + asyncio.create_task( + self.client_task(host, tunnel_server_port, remote_port + i) + ) + async for i in tools.waitable_range(concurrent_tasks) + ] + + # Wait for tasks to finish and check for exceptions + await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED) + + # If any exception was raised, raise it + for task in tasks: + task.result() + + # Queue should have all requests (concurrent_tasks*2, one for open and one for close) + self.assertEqual(req_queue.qsize(), concurrent_tasks * 2) diff --git a/tunnel-server/test/test_tunnel.py b/tunnel-server/test/test_tunnel.py index 74202859f..fe494ff13 100644 --- a/tunnel-server/test/test_tunnel.py +++ b/tunnel-server/test/test_tunnel.py @@ -61,12 +61,11 @@ class TestTunnel(IsolatedAsyncioTestCase): # Send invalid commands and see what happens # Commands are 4 bytes length, try with less and more invalid commands - consts.TIMEOUT_COMMAND = 0.1 # type: ignore # timeout is a final variable, but we need to change it for testing speed for i in range(0, 100, 10): # Set timeout to 1 seconds bad_cmd = bytes(random.randint(0, 255) for _ in range(i)) # Some garbage logger.info(f'Testing invalid command with {bad_cmd!r}') - async with tuntools.create_test_tunnel(callback=lambda x: None, port=7770, remote_port=54555) as cfg: + async with tuntools.create_test_tunnel(callback=lambda x: None, port=7770, remote_port=54555, command_timeout=0.1) as cfg: logger_mock = mock.MagicMock() with mock.patch('uds_tunnel.tunnel.logger', logger_mock): # Open connection to tunnel diff --git a/tunnel-server/test/test_udstunnel.py b/tunnel-server/test/test_udstunnel.py index cb4cd121a..d5ef0ce15 100644 --- a/tunnel-server/test/test_udstunnel.py +++ b/tunnel-server/test/test_udstunnel.py @@ -31,13 +31,12 @@ Author: Adolfo Gómez, dkmaster at dkmon dot com import typing import random import asyncio -import string import logging from unittest import IsolatedAsyncioTestCase, mock from uds_tunnel import consts -from .utils import tuntools, tools +from .utils import tuntools, tools, conf logger = logging.getLogger(__name__) @@ -48,8 +47,16 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase): logging.disable(logging.WARNING) return await super().asyncSetUp() + async def test_run_app_help(self) -> None: + # Executes the app with --help + async with tuntools.tunnel_app_runner(args=['--help']) as process: + + stdout, stderr = await process.communicate() + self.assertEqual(process.returncode, 0, f'{stdout!r} {stderr!r}') + self.assertEqual(stderr, b'') + self.assertIn(b'usage: udstunnel', stdout) + async def test_tunnel_fail_cmd(self) -> None: - consts.TIMEOUT_COMMAND = 0.1 # type: ignore # timeout is a final variable, but we need to change it for testing speed # Test on ipv4 and ipv6 for host in ('127.0.0.1', '::1'): # Remote is not really important in this tests, will fail before using it @@ -58,6 +65,7 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase): 7890, # A port not used by any other test '127.0.0.1', 13579, # A port not used by any other test + command_timeout=0.1, ) as (cfg, queue): for i in range(0, 8192, 128): # Set timeout to 1 seconds @@ -102,7 +110,6 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase): self.assertEqual(data, consts.RESPONSE_OK) async def test_tunnel_fail_open(self) -> None: - consts.TIMEOUT_COMMAND = 0.1 # type: ignore # timeout is a final variable, but we need to change it for testing speed for host in ('127.0.0.1', '::1'): # Remote is NOT important in this tests # create a remote server @@ -112,6 +119,7 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase): 7775, server.host, server.port, + command_timeout=0.1, ) as (cfg, queue): for i in range( 0, consts.TICKET_LENGTH - 1, 4 @@ -152,7 +160,7 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase): for tunnel_host in ('127.0.0.1', '::1'): async with tuntools.create_tunnel_proc( tunnel_host, - 7778, + 7778, # Not really used here server.host, server.port, use_fake_http_server=True, diff --git a/tunnel-server/test/utils/fixtures.py b/tunnel-server/test/utils/fixtures.py index 447a7fabe..95a571fa5 100644 --- a/tunnel-server/test/utils/fixtures.py +++ b/tunnel-server/test/utils/fixtures.py @@ -94,6 +94,10 @@ secret = {secret} # defaults to localhost (change if listen address is different from 0.0.0.0) allow = {allow} +# Command timeout. Command reception on tunnel will timeout after this time (in seconds) +# defaults to 3 seconds +command_timeout = {command_timeout} + use_uvloop = {use_uvloop} ''' @@ -121,6 +125,7 @@ def get_config(**overrides) -> typing.Tuple[typing.Dict[str, typing.Any], config 'uds_verify_ssl': random.choice([True, False]), # Random verify uds ssl 'secret': f'secret{random.randint(0, 100)}', # Random secret 'allow': f'{random.randint(0, 255)}.0.0.0', # Random allow + 'command_timeout': random.randint(0, 100), # Random command timeout 'use_uvloop': random.choice([True, False]), # Random use uvloop } values.update(overrides) diff --git a/tunnel-server/test/utils/tools.py b/tunnel-server/test/utils/tools.py index 9606b6951..8a87299d3 100644 --- a/tunnel-server/test/utils/tools.py +++ b/tunnel-server/test/utils/tools.py @@ -195,3 +195,8 @@ async def wait_for_port(host: str, port: int) -> None: return except ConnectionRefusedError: await asyncio.sleep(0.1) + +async def waitable_range(len: int, wait: float = 0.0001) -> typing.AsyncGenerator[int, None]: + for i in range(len): + await asyncio.sleep(wait) + yield i diff --git a/tunnel-server/test/utils/tuntools.py b/tunnel-server/test/utils/tuntools.py index fb24c4635..a5c0f173e 100644 --- a/tunnel-server/test/utils/tuntools.py +++ b/tunnel-server/test/utils/tuntools.py @@ -98,22 +98,29 @@ def create_config_file( finally: pass # Remove the files if they exists - # for filename in (cfgfile, cert_file): - # try: - # os.remove(filename) - # except Exception: - # pass + for filename in (cfgfile, cert_file): + try: + os.remove(filename) + except Exception: + pass @contextlib.asynccontextmanager async def create_tunnel_proc( listen_host: str, listen_port: int, - remote_host: str, - remote_port: int, + remote_host: str = '0.0.0.0', # Not used if response is provided + remote_port: int = 0, # Not used if response is provided *, - response: typing.Optional[typing.Mapping[str, typing.Any]] = None, + response: typing.Optional[ + typing.Union[ + typing.Callable[[bytes], typing.Mapping[str, typing.Any]], + typing.Mapping[str, typing.Any], + ] + ] = None, use_fake_http_server: bool = False, + # Configuration parameters + **kwargs, ) -> typing.AsyncGenerator[ typing.Tuple['config.ConfigurationType', typing.Optional[asyncio.Queue[bytes]]], None, @@ -126,7 +133,7 @@ async def create_tunnel_proc( listen_port (int): Port to listen on remote_host (str): Remote host to connect to remote_port (int): Remote port to connect to - response (typing.Optional[typing.Mapping[str, typing.Any]], optional): Response to send to the tunnel. Defaults to None. + response (typing.Optional[typing.Union[typing.Callable[[bytes], typing.Mapping[str, typing.Any]], typing.Mapping[str, typing.Any]]], optional): Response to send to the client. Defaults to None. use_fake_http_server (bool, optional): If True, a fake http server will be used instead of a mock. Defaults to False. Yields: @@ -134,11 +141,16 @@ async def create_tunnel_proc( and a queue with the data received by the "fake_http_server" if used, or None if not used """ + # Ensure response + if response is None: + response = conf.UDS_GET_TICKET_RESPONSE(remote_host, remote_port) + port = random.randint(20000, 30000) hhost = f'[{listen_host}]' if ':' in listen_host else listen_host args = { 'uds_server': f'http://{hhost}:{port}/uds/rest', } + args.update(kwargs) # Add extra args # If use http server instead of mock # We will setup a different context provider if use_fake_http_server: @@ -149,7 +161,7 @@ async def create_tunnel_proc( typing.Optional[asyncio.Queue[bytes]], None ]: async with create_fake_broker_server( - listen_host, port, response=resp + listen_host, port, response=response or resp ) as queue: try: yield queue @@ -166,7 +178,10 @@ async def create_tunnel_proc( 'uds_tunnel.tunnel.TunnelProtocol._read_from_uds', new_callable=tools.AsyncMock, ) as m: - m.return_value = response + if callable(response): + m.side_effect = lambda cfg, ticket, *args, **kwargs: response(ticket) # type: ignore + else: + m.return_value = response try: yield None finally: @@ -181,10 +196,6 @@ async def create_tunnel_proc( # Load config here also for testing cfg = config.read(cfgfile) - # Ensure response - if response is None: - response = conf.UDS_GET_TICKET_RESPONSE(remote_host, remote_port) - async with provider() as possible_queue: # Stats collector gs = stats.GlobalStats() @@ -283,6 +294,8 @@ async def create_test_tunnel( callback: typing.Callable[[bytes], None], port: typing.Optional[int] = None, remote_port: typing.Optional[int] = None, + # Configuration parameters + **kwargs: typing.Any, ) -> typing.AsyncGenerator['config.ConfigurationType', None]: # Generate a listening server for testing tunnel # Prepare the end of the tunnel @@ -296,6 +309,7 @@ async def create_test_tunnel( address=server.host, port=port or 7771, ipv6=':' in server.host, + **kwargs, ) with mock.patch( 'uds_tunnel.tunnel.TunnelProtocol._read_from_uds', @@ -316,9 +330,11 @@ async def create_fake_broker_server( host: str, port: int, *, - response: typing.Union[ - typing.Callable[[bytes], typing.Mapping[str, typing.Any]], - typing.Mapping[str, typing.Any], + response: typing.Optional[ + typing.Union[ + typing.Callable[[bytes], typing.Mapping[str, typing.Any]], + typing.Mapping[str, typing.Any], + ] ], ) -> typing.AsyncGenerator[asyncio.Queue[bytes], None]: # crate a fake broker server @@ -343,7 +359,7 @@ async def create_fake_broker_server( if callable(response): rr = response(data) else: - rr = response + rr = response or {} resp: bytes = ( b'HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n\r\n'