From 49dddbfce7226f6abf4b0c0626ec44c57f562b5b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adolfo=20G=C3=B3mez=20Garc=C3=ADa?= Date: Sun, 18 Dec 2022 22:16:40 +0100 Subject: [PATCH] Added full tunnel test --- tunnel-server/src/uds_tunnel/proxy.py | 2 +- tunnel-server/src/uds_tunnel/tunnel.py | 13 ++- tunnel-server/src/udstunnel.py | 4 +- tunnel-server/test/test_udstunnel.py | 147 ++++++++++++++++++++++--- tunnel-server/test/utils/tuntools.py | 30 ++++- 5 files changed, 169 insertions(+), 27 deletions(-) diff --git a/tunnel-server/src/uds_tunnel/proxy.py b/tunnel-server/src/uds_tunnel/proxy.py index 062075171..bc3c12ade 100644 --- a/tunnel-server/src/uds_tunnel/proxy.py +++ b/tunnel-server/src/uds_tunnel/proxy.py @@ -51,7 +51,6 @@ class Proxy: def __init__(self, cfg: 'config.ConfigurationType', ns: 'Namespace') -> None: self.cfg = cfg self.ns = ns - self.finished = asyncio.Future() # not done yet # Method responsible of proxying requests async def __call__(self, source: socket.socket, context: 'ssl.SSLContext') -> None: @@ -71,6 +70,7 @@ class Proxy: loop = asyncio.get_running_loop() # Handshake correct in this point, upgrade the connection to TSL and let # the protocol controller do the rest + self.finished = loop.create_future() # Upgrade connection to SSL, and use asyncio to handle the rest try: diff --git a/tunnel-server/src/uds_tunnel/tunnel.py b/tunnel-server/src/uds_tunnel/tunnel.py index fb34c0313..630cedc67 100644 --- a/tunnel-server/src/uds_tunnel/tunnel.py +++ b/tunnel-server/src/uds_tunnel/tunnel.py @@ -66,7 +66,7 @@ class TunnelProtocol(asyncio.Protocol): # counter counter: stats.StatsSingleCounter # If there is a timeout task running - timeout_task: typing.Optional[asyncio.Task] + timeout_task: typing.Optional[asyncio.Task] = None def __init__( self, owner: 'proxy.Proxy', other_side: typing.Optional['TunnelProtocol'] = None @@ -83,6 +83,8 @@ class TunnelProtocol(asyncio.Protocol): self.stats_manager = stats.Stats(owner.ns) self.counter = self.stats_manager.as_sent_counter() self.runner = self.do_command + # Set starting timeout task, se we dont get hunged on connections without data + self.set_timeout(consts.TIMEOUT_COMMAND) # transport is undefined until connection_made is called self.cmd = b'' @@ -90,15 +92,14 @@ class TunnelProtocol(asyncio.Protocol): self.owner = owner self.source = ('', 0) self.destination = ('', 0) - self.timeout_task = None - # Set starting timeout task, se we dont get hunged on connections without data - self.set_timeout(consts.TIMEOUT_COMMAND) def process_open(self) -> None: # Open Command has the ticket behind it if len(self.cmd) < consts.TICKET_LENGTH + consts.COMMAND_LENGTH: + # Reactivate timeout, will be deactivated on do_command + self.set_timeout(consts.TIMEOUT_COMMAND) return # Wait for more data to complete OPEN command # Ticket received, now process it with UDS @@ -192,6 +193,7 @@ class TunnelProtocol(asyncio.Protocol): self.close_connection() async def timeout(self, wait: int) -> None: + """ Timeout can only occur while waiting for a command.""" try: await asyncio.sleep(wait) logger.error('TIMEOUT FROM %s', self.pretty_source()) @@ -213,7 +215,8 @@ class TunnelProtocol(asyncio.Protocol): self.timeout_task = asyncio.create_task(self.timeout(wait)) def clean_timeout(self) -> None: - """Clean the timeout task if any.""" + """Clean the timeout task if any. + """ if self.timeout_task: self.timeout_task.cancel() self.timeout_task = None diff --git a/tunnel-server/src/udstunnel.py b/tunnel-server/src/udstunnel.py index 28cbd6b26..c688dced5 100755 --- a/tunnel-server/src/udstunnel.py +++ b/tunnel-server/src/udstunnel.py @@ -235,7 +235,7 @@ def tunnel_main(args: 'argparse.Namespace') -> None: # If running as root, and requested drop privileges after port bind if os.getuid() == 0 and cfg.user: - logger.debug('Changing to user %s', cfg.user) + logger.debug('Changing to user %s', cfg.user) pwu = pwd.getpwnam(cfg.user) # os.setgid(pwu.pw_gid) os.setuid(pwu.pw_uid) @@ -271,7 +271,7 @@ def tunnel_main(args: 'argparse.Namespace') -> None: prcs = processes.Processes(tunnel_proc_async, cfg, stats_collector.ns) - running.set() + running.set() # Signal we are running with ThreadPoolExecutor(max_workers=256) as executor: try: while running.is_set(): diff --git a/tunnel-server/test/test_udstunnel.py b/tunnel-server/test/test_udstunnel.py index 3931943b6..c9d956607 100644 --- a/tunnel-server/test/test_udstunnel.py +++ b/tunnel-server/test/test_udstunnel.py @@ -31,12 +31,13 @@ Author: Adolfo Gómez, dkmaster at dkmon dot com import typing import random import asyncio +import string import logging from unittest import IsolatedAsyncioTestCase, mock from uds_tunnel import consts -from .utils import tuntools +from .utils import tuntools, tools logger = logging.getLogger(__name__) @@ -49,25 +50,145 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase): async def test_tunnel_fail_cmd(self) -> None: consts.TIMEOUT_COMMAND = 0.1 # type: ignore # timeout is a final variable, but we need to change it for testing speed - for i in range(0, 100, 10): - # Set timeout to 1 seconds - bad_cmd = bytes(random.randint(0, 255) for _ in range(i)) # Some garbage - logger.info(f'Testing invalid command with {bad_cmd!r}') - for host in ('127.0.0.1', '::1'): - # Remote is not really important in this tests, will fail before using it - async with tuntools.create_tunnel_proc( - host, 7777, '127.0.0.1', 12345, workers=1 - ) as cfg: + # Test on ipv4 and ipv6 + for host in ('127.0.0.1', '::1'): + # Remote is not really important in this tests, will fail before using it + async with tuntools.create_tunnel_proc( + host, + 7777, + '127.0.0.1', + 12345, + ) as cfg: + for i in range(0, 8192, 128): + # Set timeout to 1 seconds + bad_cmd = bytes( + random.randint(0, 255) for _ in range(i) + ) # Some garbage + logger.info(f'Testing invalid command with {bad_cmd!r}') # On full, we need the handshake to be done, before connecting - async with tuntools.open_tunnel_client(cfg, use_tunnel_handshake=True) as (creader, cwriter): + # Our "test" server will simple "eat" the handshake, but we need to do it + async with tuntools.open_tunnel_client( + cfg, use_tunnel_handshake=True + ) as (creader, cwriter): cwriter.write(bad_cmd) await cwriter.drain() - # Read response + # Read response data = await creader.read(1024) # if len(bad_cmd) < consts.COMMAND_LENGTH, response will be RESPONSE_ERROR_TIMEOUT if len(bad_cmd) >= consts.COMMAND_LENGTH: self.assertEqual(data, consts.RESPONSE_ERROR_COMMAND) else: self.assertEqual(data, consts.RESPONSE_ERROR_TIMEOUT) - + async def test_tunnel_test(self) -> None: + for host in ('127.0.0.1', '::1'): + # Remote is not really important in this tests, will return ok before using it (this is a TEST command, not OPEN) + async with tuntools.create_tunnel_proc( + host, + 7777, + '127.0.0.1', + 12345, + ) as cfg: + for i in range(10): # Several times + # On full, we need the handshake to be done, before connecting + # Our "test" server will simple "eat" the handshake, but we need to do it + async with tuntools.open_tunnel_client( + cfg, use_tunnel_handshake=True + ) as (creader, cwriter): + cwriter.write(consts.COMMAND_TEST) + await cwriter.drain() + # Read response + data = await creader.read(1024) + self.assertEqual(data, consts.RESPONSE_OK) + + async def test_tunnel_fail_open(self) -> None: + consts.TIMEOUT_COMMAND = 0.1 # type: ignore # timeout is a final variable, but we need to change it for testing speed + for host in ('127.0.0.1', '::1'): + # Remote is NOT important in this tests + # create a remote server + async with tools.AsyncTCPServer(host=host, port=5445) as server: + async with tuntools.create_tunnel_proc( + host, + 7777, + server.host, + server.port, + ) as cfg: + for i in range( + 0, consts.TICKET_LENGTH - 1, 4 + ): # All will fail. Any longer will be processed, and mock will return correct don't matter the ticket + # Ticket must contain only letters and numbers + ticket = ''.join( + random.choice(string.ascii_letters + string.digits) + for _ in range(i) + ).encode() + # On full, we need the handshake to be done, before connecting + # Our "test" server will simple "eat" the handshake, but we need to do it + async with tuntools.open_tunnel_client( + cfg, use_tunnel_handshake=True + ) as (creader, cwriter): + cwriter.write(consts.COMMAND_OPEN) + # fake ticket, consts.TICKET_LENGTH bytes long, letters and numbers. Use a random ticket, + cwriter.write(ticket) + + await cwriter.drain() + # Read response + data = await creader.read(1024) + self.assertEqual(data, consts.RESPONSE_ERROR_TIMEOUT) + + async def test_tunnel_open(self) -> None: + for host in ('127.0.0.1', '::1'): + received: bytes = b'' + callback_invoked: asyncio.Event = asyncio.Event() + + def callback(data: bytes) -> None: + nonlocal received + received += data + # if data contains EOS marcker ('STREAM_END'), we are done + if b'STREAM_END' in data: + callback_invoked.set() + + # Remote is important in this tests + # create a remote server + async with tools.AsyncTCPServer( + host=host, port=5445, callback=callback + ) as server: + async with tuntools.create_tunnel_proc( + host, + 7777, + server.host, + server.port, + ) as cfg: + for i in range(10): + # Create a random valid ticket + ticket = ''.join( + random.choice(string.ascii_letters + string.digits) + for _ in range(consts.TICKET_LENGTH) + ).encode() + # On full, we need the handshake to be done, before connecting + # Our "test" server will simple "eat" the handshake, but we need to do it + async with tuntools.open_tunnel_client( + cfg, use_tunnel_handshake=True + ) as (creader, cwriter): + cwriter.write(consts.COMMAND_OPEN) + # fake ticket, consts.TICKET_LENGTH bytes long, letters and numbers. Use a random ticket, + cwriter.write(ticket) + + await cwriter.drain() + # Read response + data = await creader.read(1024) + self.assertEqual(data, consts.RESPONSE_OK) + + # Data sent will be received by server + # One single write will ensure all data is on same packet + test_str = b'Some Random Data' + bytes(random.randint(0, 255) for _ in range(512)) + b'STREAM_END' + # Clean received data + received = b'' + # And reset event + callback_invoked.clear() + + cwriter.write(test_str) + await cwriter.drain() + + # Wait for callback to be invoked + await callback_invoked.wait() + self.assertEqual(received, test_str) diff --git a/tunnel-server/test/utils/tuntools.py b/tunnel-server/test/utils/tuntools.py index 019799e74..61e677887 100644 --- a/tunnel-server/test/utils/tuntools.py +++ b/tunnel-server/test/utils/tuntools.py @@ -34,10 +34,9 @@ import io import logging import socket import ssl -import tempfile -import threading -import random +import os import typing +import json from unittest import mock import multiprocessing @@ -61,7 +60,7 @@ async def create_tunnel_proc( remote_host: str, remote_port: int, *, - workers: int = 1 + response: typing.Optional[typing.Mapping[str, typing.Any]] = None ) -> typing.AsyncGenerator['config.ConfigurationType', None]: # Create the ssl cert cert, key, password = certs.selfSignedCert(listen_host, use_password=False) @@ -76,28 +75,35 @@ async def create_tunnel_proc( address=listen_host, port=listen_port, ipv6=':' in listen_host, + loglevel='DEBUG', ssl_certificate=cert_file, ssl_certificate_key='', ssl_password=password, ssl_ciphers='', ssl_dhparam='', - workers=workers, ) args = mock.MagicMock() args.config = io.StringIO(fixtures.TEST_CONFIG.format(**values)) args.ipv6 = ':' in listen_host + return_value: typing.Mapping[str, typing.Any] + # Ensure response + if response is None: + response = conf.UDS_GET_TICKET_RESPONSE(remote_host, remote_port) + with mock.patch( 'uds_tunnel.tunnel.TunnelProtocol._readFromUDS', new_callable=tools.AsyncMock, ) as m: - m.return_value = conf.UDS_GET_TICKET_RESPONSE(remote_host, remote_port) + m.return_value = response # Stats collector gs = stats.GlobalStats() # Pipe to send data to tunnel own_end, other_end = multiprocessing.Pipe() + udstunnel.setup_log(cfg) + # Set running flag udstunnel.running.set() @@ -141,6 +147,18 @@ async def create_tunnel_proc( await server.wait_closed() logger.info('Server closed') + # Ensure log file are removed + rootlog = logging.getLogger() + for h in rootlog.handlers: + if isinstance(h, logging.FileHandler): + h.close() + # Remove the file if possible, do not fail + try: + os.unlink(h.baseFilename) + except Exception: + pass + + async def create_tunnel_server( cfg: 'config.ConfigurationType', context: 'ssl.SSLContext'