From 370799912f89073d7f2582cdf64a83fe4d68f69c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adolfo=20G=C3=B3mez=20Garc=C3=ADa?= Date: Sat, 17 Dec 2022 21:03:00 +0100 Subject: [PATCH] Added test to udstunnel man async proc --- tunnel-server/src/uds_tunnel/config.py | 2 + tunnel-server/src/udstunnel.py | 77 ++++--- tunnel-server/test/test_config_file.py | 2 +- tunnel-server/test/test_tunnel.py | 91 +------- tunnel-server/test/test_tunnel_full.py | 111 ---------- tunnel-server/test/test_tunnel_helpers.py | 2 +- tunnel-server/test/test_udstunnel.py | 67 ++++++ tunnel-server/test/utils/certs.py | 18 +- tunnel-server/test/{ => utils}/fixtures.py | 0 tunnel-server/test/utils/tuntools.py | 235 +++++++++++++++++++++ 10 files changed, 369 insertions(+), 236 deletions(-) delete mode 100644 tunnel-server/test/test_tunnel_full.py create mode 100644 tunnel-server/test/test_udstunnel.py rename tunnel-server/test/{ => utils}/fixtures.py (100%) create mode 100644 tunnel-server/test/utils/tuntools.py diff --git a/tunnel-server/src/uds_tunnel/config.py b/tunnel-server/src/uds_tunnel/config.py index 9fae1a679..3b302ecd3 100644 --- a/tunnel-server/src/uds_tunnel/config.py +++ b/tunnel-server/src/uds_tunnel/config.py @@ -56,6 +56,7 @@ class ConfigurationType(typing.NamedTuple): ssl_certificate: str ssl_certificate_key: str + ssl_password: str ssl_ciphers: str ssl_dhparam: str @@ -122,6 +123,7 @@ def read( workers=int(uds.get('workers', '0')) or multiprocessing.cpu_count(), ssl_certificate=uds['ssl_certificate'], ssl_certificate_key=uds.get('ssl_certificate_key', ''), + ssl_password=uds.get('ssl_password', ''), ssl_ciphers=uds.get('ssl_ciphers'), ssl_dhparam=uds.get('ssl_dhparam'), uds_server=uds_server, diff --git a/tunnel-server/src/udstunnel.py b/tunnel-server/src/udstunnel.py index 0e1931995..3d01c19c1 100755 --- a/tunnel-server/src/udstunnel.py +++ b/tunnel-server/src/udstunnel.py @@ -39,6 +39,8 @@ import ssl import socket import logging from concurrent.futures import ThreadPoolExecutor +# event for stop notification +import threading import typing try: @@ -62,12 +64,12 @@ if typing.TYPE_CHECKING: logger = logging.getLogger(__name__) -do_stop = False +running: threading.Event = threading.Event() def stop_signal(signum: int, frame: typing.Any) -> None: - global do_stop - do_stop = True + global running + running.clear() logger.debug('SIGNAL %s, frame: %s', signum, frame) @@ -119,8 +121,13 @@ async def tunnel_proc_async( ] = pipe.recv() if msg: return msg + except EOFError: + logger.debug('Parent process closed connection') + pipe.close() + return None, None except Exception: logger.exception('Receiving data from parent process') + pipe.close() return None, None async def run_server() -> None: @@ -129,11 +136,15 @@ async def tunnel_proc_async( # Generate SSL context context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) - + args: typing.Dict[str, typing.Any] = { + 'certfile': cfg.ssl_certificate, + } if cfg.ssl_certificate_key: - context.load_cert_chain(cfg.ssl_certificate, cfg.ssl_certificate_key) - else: - context.load_cert_chain(cfg.ssl_certificate) + args['keyfile'] = cfg.ssl_certificate_key + if cfg.ssl_password: + args['password'] = cfg.ssl_password + + context.load_cert_chain(**args) if cfg.ssl_ciphers: context.set_ciphers(cfg.ssl_ciphers) @@ -141,29 +152,37 @@ async def tunnel_proc_async( if cfg.ssl_dhparam: context.load_dh_params(cfg.ssl_dhparam) - while True: - address: typing.Optional[typing.Tuple[str, int]] = ('', 0) - try: - (sock, address) = await loop.run_in_executor(None, get_socket) - if not sock: - break # No more sockets, exit - logger.debug(f'CONNECTION from {address!r} (pid: {os.getpid()})') - tasks.append(asyncio.create_task(tunneler(sock, context))) - except Exception: - logger.error('NEGOTIATION ERROR from %s', address[0] if address else 'unknown') + try: + while True: + address: typing.Optional[typing.Tuple[str, int]] = ('', 0) + try: + (sock, address) = await loop.run_in_executor(None, get_socket) + if not sock: + break # No more sockets, exit + logger.debug(f'CONNECTION from {address!r} (pid: {os.getpid()})') + tasks.append(asyncio.create_task(tunneler(sock, context))) + except asyncio.CancelledError: + raise + except Exception: + logger.error('NEGOTIATION ERROR from %s', address[0] if address else 'unknown') + except asyncio.CancelledError: + pass # Stop # create task for server tasks.append(asyncio.create_task(run_server())) - while tasks and not do_stop: - to_wait = tasks[:] # Get a copy of the list, and clean the original - # Wait for tasks to finish - done, _ = await asyncio.wait(to_wait, return_when=asyncio.FIRST_COMPLETED) - # Remove finished tasks - for task in done: - tasks.remove(task) - if task.exception(): - logger.exception('TUNNEL ERROR') + try: + while tasks and running.is_set(): + to_wait = tasks[:] # Get a copy of the list, and clean the original + # Wait for tasks to finish + done, _ = await asyncio.wait(to_wait, return_when=asyncio.FIRST_COMPLETED, timeout=2) + # Remove finished tasks + for task in done: + tasks.remove(task) + if task.exception(): + logger.exception('TUNNEL ERROR') + except asyncio.CancelledError: + running.clear() # ensure we stop # If any task is still running, cancel it for task in tasks: @@ -244,16 +263,18 @@ def tunnel_main(args: 'argparse.Namespace') -> None: signal.signal(signal.SIGINT, stop_signal) signal.signal(signal.SIGTERM, stop_signal) except Exception as e: - # Signal not available on threads, and testing uses threads + # Signal not available on threads, and we use threads on tests, + # so we will ignore this because on tests signals are not important logger.warning('Signal not available: %s', e) stats_collector = stats.GlobalStats() prcs = processes.Processes(tunnel_proc_async, cfg, stats_collector.ns) + running.set() with ThreadPoolExecutor(max_workers=256) as executor: try: - while not do_stop: + while running.is_set(): try: client, addr = sock.accept() logger.info('CONNECTION from %s', addr) diff --git a/tunnel-server/test/test_config_file.py b/tunnel-server/test/test_config_file.py index 349b90b2a..5b9eb78d4 100644 --- a/tunnel-server/test/test_config_file.py +++ b/tunnel-server/test/test_config_file.py @@ -32,7 +32,7 @@ import hashlib from unittest import TestCase -from . import fixtures +from .utils import fixtures class TestConfigFile(TestCase): diff --git a/tunnel-server/test/test_tunnel.py b/tunnel-server/test/test_tunnel.py index ba23f865a..d7e97a5d1 100644 --- a/tunnel-server/test/test_tunnel.py +++ b/tunnel-server/test/test_tunnel.py @@ -30,10 +30,7 @@ Author: Adolfo Gómez, dkmaster at dkmon dot com ''' import typing import random -import asyncio -import contextlib import socket -import ssl import logging import multiprocessing from unittest import IsolatedAsyncioTestCase, mock @@ -41,11 +38,7 @@ from unittest import IsolatedAsyncioTestCase, mock from udstunnel import process_connection from uds_tunnel import tunnel, consts -from . import fixtures -from .utils import tools, certs, conf - -if typing.TYPE_CHECKING: - from uds_tunnel import config +from .utils import tuntools logger = logging.getLogger(__name__) @@ -63,16 +56,16 @@ class TestTunnel(IsolatedAsyncioTestCase): # Send invalid commands and see what happens # Commands are 4 bytes length, try with less and more invalid commands + consts.TIMEOUT_COMMAND = 0.1 # type: ignore # timeout is a final variable, but we need to change it for testing speed for i in range(0, 100, 10): # Set timeout to 1 seconds bad_cmd = bytes(random.randint(0, 255) for _ in range(i)) # Some garbage - consts.TIMEOUT_COMMAND = 0.1 # type: ignore # timeout is a final variable, but we need to change it for testing speed logger.info(f'Testing invalid command with {bad_cmd!r}') - async with TestTunnel.create_test_tunnel(callback=lambda x: None) as cfg: + async with tuntools.create_test_tunnel(callback=lambda x: None) as cfg: logger_mock = mock.MagicMock() with mock.patch('uds_tunnel.tunnel.logger', logger_mock): # Open connection to tunnel - async with TestTunnel.open_tunnel(cfg) as (reader, writer): + async with tuntools.open_tunnel_client(cfg) as (reader, writer): # Send data writer.write(bad_cmd) await writer.drain() @@ -151,79 +144,3 @@ class TestTunnel(IsolatedAsyncioTestCase): # recv()[0] will be a copy of the socket, we don't care about it self.assertEqual(other_conn.recv()[1], ('host', 'port')) - @staticmethod - async def create_tunnel_server( - cfg: 'config.ConfigurationType', context: 'ssl.SSLContext' - ) -> 'asyncio.Server': - # Create fake proxy - proxy = mock.MagicMock() - proxy.cfg = cfg - proxy.ns = mock.MagicMock() - proxy.ns.current = 0 - proxy.ns.total = 0 - proxy.ns.sent = 0 - proxy.ns.recv = 0 - proxy.counter = 0 - - loop = asyncio.get_running_loop() - - # Create an asyncio listen socket on cfg.listen_host:cfg.listen_port - return await loop.create_server( - lambda: tunnel.TunnelProtocol(proxy), - cfg.listen_address, - cfg.listen_port, - ssl=context, - family=socket.AF_INET6 - if cfg.listen_ipv6 or ':' in cfg.listen_address - else socket.AF_INET, - ) - - @staticmethod - @contextlib.asynccontextmanager - async def create_test_tunnel( - *, callback: typing.Callable[[bytes], None] - ) -> typing.AsyncGenerator['config.ConfigurationType', None]: - # Generate a listening server for testing tunnel - # Prepare the end of the tunnel - async with tools.AsyncTCPServer(port=54876, callback=callback) as server: - # Create a tunnel to localhost 13579 - # SSl cert for tunnel server - with certs.ssl_context(server.host) as (ssl_ctx, _): - _, cfg = fixtures.get_config( - address=server.host, - port=7777, - ) - with mock.patch( - 'uds_tunnel.tunnel.TunnelProtocol._readFromUDS', - new_callable=tools.AsyncMock, - ) as m: - m.return_value = conf.UDS_GET_TICKET_RESPONSE( - server.host, server.port - ) - - tunnel_server = await TestTunnel.create_tunnel_server(cfg, ssl_ctx) - yield cfg - tunnel_server.close() - await tunnel_server.wait_closed() - - @staticmethod - @contextlib.asynccontextmanager - async def open_tunnel( - cfg: 'config.ConfigurationType', - ) -> typing.AsyncGenerator[ - typing.Tuple[asyncio.StreamReader, asyncio.StreamWriter], None - ]: - """opens an ssl socket to the tunnel server""" - if cfg.listen_ipv6 or ':' in cfg.listen_address: - family = socket.AF_INET6 - else: - family = socket.AF_INET - context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - context.check_hostname = False - context.verify_mode = ssl.CERT_NONE - reader, writer = await asyncio.open_connection( - cfg.listen_address, cfg.listen_port, ssl=context, family=family - ) - yield reader, writer - writer.close() - await writer.wait_closed() diff --git a/tunnel-server/test/test_tunnel_full.py b/tunnel-server/test/test_tunnel_full.py deleted file mode 100644 index 2c6e04df5..000000000 --- a/tunnel-server/test/test_tunnel_full.py +++ /dev/null @@ -1,111 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Copyright (c) 2022 Virtual Cable S.L.U. -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without modification, -# are permitted provided that the following conditions are met: -# -# * Redistributions of source code must retain the above copyright notice, -# this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# * Neither the name of Virtual Cable S.L. nor the names of its contributors -# may be used to endorse or promote products derived from this software -# without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -''' -Author: Adolfo Gómez, dkmaster at dkmon dot com -''' -import typing -import random -import asyncio -import contextlib -import io -import socket -import ssl -import logging -import multiprocessing -import tempfile -import threading -from unittest import IsolatedAsyncioTestCase, mock - -from uds_tunnel import tunnel, consts -import udstunnel - - -from . import fixtures -from .utils import tools, certs, conf - -if typing.TYPE_CHECKING: - from uds_tunnel import config - -logger = logging.getLogger(__name__) - - -class TestTunnel(IsolatedAsyncioTestCase): - @staticmethod - @contextlib.contextmanager - def create_tunnel_thread( - listen_host: str, - listen_port: int, - remote_host: str, - remote_port: int, - *, - workers: int = 1 - ) -> typing.Generator[None, None, None]: - # Create the ssl cert - cert, key, password = certs.selfSignedCert(listen_host) - # Create the certificate file on /tmp - with tempfile.NamedTemporaryFile() as cert_file: - cert_file.write(cert.encode()) - cert_file.write(key.encode()) - cert_file.flush() - - # Config file for the tunnel, ignore readed - values, _ = fixtures.get_config( - address=listen_host, - port=listen_port, - ssl_certificate=cert_file.name, - ssl_certificate_key='', - ssl_ciphers='', - ssl_dhparam='', - workers=workers, - ) - args = mock.MagicMock() - args.config = io.StringIO(fixtures.TEST_CONFIG.format(**values)) - args.ipv6 = ':' in listen_host - - with mock.patch( - 'uds_tunnel.tunnel.TunnelProtocol._readFromUDS', - new_callable=tools.AsyncMock, - ) as m: - m.return_value = conf.UDS_GET_TICKET_RESPONSE(remote_host, remote_port) - - # Create a thread to run the tunnel, udstunnel.tunnel_main will block - # until the tunnel is closed - thread = threading.Thread(target=udstunnel.tunnel_main, args=(args,)) - thread.start() - yield - # Signal stop to thead - udstunnel.do_stop = True - - # Wait for thread to finish - thread.join() - - async def test_tunnel_full(self) -> None: - with self.create_tunnel_thread( - '127.0.0.1', 7777, '127.0.0.1', 12345, workers=1 - ): - await asyncio.sleep(4) diff --git a/tunnel-server/test/test_tunnel_helpers.py b/tunnel-server/test/test_tunnel_helpers.py index 9e3cfcd76..da0e85ed9 100644 --- a/tunnel-server/test/test_tunnel_helpers.py +++ b/tunnel-server/test/test_tunnel_helpers.py @@ -37,7 +37,7 @@ from unittest import IsolatedAsyncioTestCase, mock from uds_tunnel import tunnel, consts -from . import fixtures +from .utils import fixtures from .utils import tools, conf logger = logging.getLogger(__name__) diff --git a/tunnel-server/test/test_udstunnel.py b/tunnel-server/test/test_udstunnel.py new file mode 100644 index 000000000..492f64391 --- /dev/null +++ b/tunnel-server/test/test_udstunnel.py @@ -0,0 +1,67 @@ +# -*- coding: utf-8 -*- +# +# Copyright (c) 2022 Virtual Cable S.L.U. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of Virtual Cable S.L. nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +''' +Author: Adolfo Gómez, dkmaster at dkmon dot com +''' +import typing +import random +import asyncio +import logging +from unittest import IsolatedAsyncioTestCase, mock + +from uds_tunnel import consts + +from .utils import tuntools + +logger = logging.getLogger(__name__) + + +class TestUDSTunnel(IsolatedAsyncioTestCase): + + async def test_tunnel_fail_cmd_full(self) -> None: + consts.TIMEOUT_COMMAND = 0.1 # type: ignore # timeout is a final variable, but we need to change it for testing speed + for i in range(0, 100, 10): + # Set timeout to 1 seconds + bad_cmd = bytes(random.randint(0, 255) for _ in range(i)) # Some garbage + logger.info(f'Testing invalid command with {bad_cmd!r}') + async with tuntools.create_tunnel_proc( + '127.0.0.1', 7777, '127.0.0.1', 12345, workers=1 + ) as cfg: + # On full, we need the handshake to be done, before connecting + async with tuntools.open_tunnel_client(cfg, use_tunnel_handshake=True) as (creader, cwriter): + cwriter.write(bad_cmd) + await cwriter.drain() + # Read response + data = await creader.read(1024) + # if len(bad_cmd) < consts.COMMAND_LENGTH, response will be RESPONSE_ERROR_TIMEOUT + if len(bad_cmd) >= consts.COMMAND_LENGTH: + self.assertEqual(data, consts.RESPONSE_ERROR_COMMAND) + else: + self.assertEqual(data, consts.RESPONSE_ERROR_TIMEOUT) + + diff --git a/tunnel-server/test/utils/certs.py b/tunnel-server/test/utils/certs.py index 3a1623381..9a1ac5a63 100644 --- a/tunnel-server/test/utils/certs.py +++ b/tunnel-server/test/utils/certs.py @@ -46,7 +46,7 @@ from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import rsa -def selfSignedCert(ip: str) -> typing.Tuple[str, str, str]: +def selfSignedCert(ip: str, use_password: bool = True) -> typing.Tuple[str, str, str]: key = rsa.generate_private_key( public_exponent=65537, key_size=2048, @@ -72,14 +72,16 @@ def selfSignedCert(ip: str) -> typing.Tuple[str, str, str]: .add_extension(san, False) .sign(key, hashes.SHA256(), default_backend()) ) - + args: typing.Dict[str, typing.Any] = { + 'encoding': serialization.Encoding.PEM, + 'format': serialization.PrivateFormat.TraditionalOpenSSL, + } + if use_password: + args['encryption_algorithm'] = serialization.BestAvailableEncryption(password.encode()) + else: + args['encryption_algorithm'] = serialization.NoEncryption() return ( - key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.BestAvailableEncryption( - password.encode() - ), + key.private_bytes(**args ).decode(), cert.public_bytes(encoding=serialization.Encoding.PEM).decode(), password, diff --git a/tunnel-server/test/fixtures.py b/tunnel-server/test/utils/fixtures.py similarity index 100% rename from tunnel-server/test/fixtures.py rename to tunnel-server/test/utils/fixtures.py diff --git a/tunnel-server/test/utils/tuntools.py b/tunnel-server/test/utils/tuntools.py new file mode 100644 index 000000000..9c93e6983 --- /dev/null +++ b/tunnel-server/test/utils/tuntools.py @@ -0,0 +1,235 @@ +# -*- coding: utf-8 -*- +# +# Copyright (c) 2022 Virtual Cable S.L.U. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of Virtual Cable S.L. nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +''' +Author: Adolfo Gómez, dkmaster at dkmon dot com +''' +import asyncio +import contextlib +import io +import logging +import socket +import ssl +import tempfile +import threading +import random +import typing +from unittest import mock +import multiprocessing + +import udstunnel +from uds_tunnel import consts, tunnel, stats + +from . import certs, conf, fixtures, tools + +if typing.TYPE_CHECKING: + from uds_tunnel import config + from multiprocessing.connection import Connection + + +logger = logging.getLogger(__name__) + + +@contextlib.asynccontextmanager +async def create_tunnel_proc( + listen_host: str, + listen_port: int, + remote_host: str, + remote_port: int, + *, + workers: int = 1 +) -> typing.AsyncGenerator['config.ConfigurationType', None]: + # Create the ssl cert + cert, key, password = certs.selfSignedCert(listen_host, use_password=False) + # Create the certificate file on /tmp + cert_file = '/tmp/tunnel_full_cert.pem' + with open(cert_file, 'w') as f: + f.write(key) + f.write(cert) + + # Config file for the tunnel, ignore readed + values, cfg = fixtures.get_config( + address=listen_host, + port=listen_port, + ssl_certificate=cert_file, + ssl_certificate_key='', + ssl_password=password, + ssl_ciphers='', + ssl_dhparam='', + workers=workers, + ) + args = mock.MagicMock() + args.config = io.StringIO(fixtures.TEST_CONFIG.format(**values)) + args.ipv6 = ':' in listen_host + + with mock.patch( + 'uds_tunnel.tunnel.TunnelProtocol._readFromUDS', + new_callable=tools.AsyncMock, + ) as m: + m.return_value = conf.UDS_GET_TICKET_RESPONSE(remote_host, remote_port) + + # Stats collector + gs = stats.GlobalStats() + # Pipe to send data to tunnel + own_end, other_end = multiprocessing.Pipe() + + # Set running flag + udstunnel.running.set() + + # Create the tunnel task + task = asyncio.create_task(udstunnel.tunnel_proc_async(other_end, cfg, gs.ns)) + + # Create a small asyncio server that reads the handshake, + # and sends the socket to the tunnel_proc_async using the pipe + # the pipe message will be typing.Tuple[socket.socket, typing.Tuple[str, int]] + # socket and address + async def client_connected_db(reader, writer): + # Read the handshake + data = await reader.read(1024) + # For testing, we ignore the handshake value + # Send the socket to the tunnel + own_end.send( + ( + writer.get_extra_info('socket').dup(), + writer.get_extra_info('peername'), + ) + ) + # Close the socket + writer.close() + + server = await asyncio.start_server( + client_connected_db, + listen_host, + listen_port, + ) + try: + yield cfg + finally: + # Close the pipe (both ends) + own_end.close() + + task.cancel() + # wait for the task to finish + await task + + server.close() + await server.wait_closed() + logger.info('Server closed') + + +async def create_tunnel_server( + cfg: 'config.ConfigurationType', context: 'ssl.SSLContext' +) -> 'asyncio.Server': + # Create fake proxy + proxy = mock.MagicMock() + proxy.cfg = cfg + proxy.ns = mock.MagicMock() + proxy.ns.current = 0 + proxy.ns.total = 0 + proxy.ns.sent = 0 + proxy.ns.recv = 0 + proxy.counter = 0 + + loop = asyncio.get_running_loop() + + # Create an asyncio listen socket on cfg.listen_host:cfg.listen_port + return await loop.create_server( + lambda: tunnel.TunnelProtocol(proxy), + cfg.listen_address, + cfg.listen_port, + ssl=context, + family=socket.AF_INET6 + if cfg.listen_ipv6 or ':' in cfg.listen_address + else socket.AF_INET, + ) + + +@contextlib.asynccontextmanager +async def create_test_tunnel( + *, callback: typing.Callable[[bytes], None] +) -> typing.AsyncGenerator['config.ConfigurationType', None]: + # Generate a listening server for testing tunnel + # Prepare the end of the tunnel + async with tools.AsyncTCPServer(port=54876, callback=callback) as server: + # Create a tunnel to localhost 13579 + # SSl cert for tunnel server + with certs.ssl_context(server.host) as (ssl_ctx, _): + _, cfg = fixtures.get_config( + address=server.host, + port=7777, + ) + with mock.patch( + 'uds_tunnel.tunnel.TunnelProtocol._readFromUDS', + new_callable=tools.AsyncMock, + ) as m: + m.return_value = conf.UDS_GET_TICKET_RESPONSE(server.host, server.port) + + tunnel_server = await create_tunnel_server(cfg, ssl_ctx) + try: + yield cfg + finally: + tunnel_server.close() + await tunnel_server.wait_closed() + + +@contextlib.asynccontextmanager +async def open_tunnel_client( + cfg: 'config.ConfigurationType', + use_tunnel_handshake: bool = False, +) -> typing.AsyncGenerator[ + typing.Tuple[asyncio.StreamReader, asyncio.StreamWriter], None +]: + """opens an ssl socket to the tunnel server""" + loop = asyncio.get_running_loop() + if cfg.listen_ipv6 or ':' in cfg.listen_address: + family = socket.AF_INET6 + else: + family = socket.AF_INET + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + + if not use_tunnel_handshake: + reader, writer = await asyncio.open_connection( + cfg.listen_address, cfg.listen_port, ssl=context, family=family + ) + else: + # Open the socket, send handshake and then upgrade to ssl, non blocking + sock = socket.socket(family, socket.SOCK_STREAM) + # Set socket to non blocking + sock.setblocking(False) + await loop.sock_connect(sock, (cfg.listen_address, cfg.listen_port)) + await loop.sock_sendall(sock, consts.HANDSHAKE_V1) + # upgrade to ssl + reader, writer = await asyncio.open_connection( + sock=sock, ssl=context, server_hostname=cfg.listen_address + ) + try: + yield reader, writer + finally: + writer.close() + await writer.wait_closed()