From 472299f87850c93781c9d9f134b9c48c0c16d389 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adolfo=20G=C3=B3mez=20Garc=C3=ADa?= Date: Tue, 20 Dec 2022 18:31:04 +0100 Subject: [PATCH] Small fixes and more tests. Refactoring. --- tunnel-server/.gitignore | 2 + tunnel-server/coverage.ini | 3 ++ tunnel-server/pytest.ini | 3 +- tunnel-server/src/uds_tunnel/tunnel.py | 20 ++++---- tunnel-server/test/test_app_concurrency.py | 52 +++++++++++++++++-- tunnel-server/test/test_tunnel.py | 2 +- tunnel-server/test/test_tunnel_helpers.py | 14 ++--- tunnel-server/test/test_udstunnel.py | 58 +++++++++++++++------ tunnel-server/test/utils/tools.py | 16 +++--- tunnel-server/test/utils/tuntools.py | 59 +++++++++++++++++----- 10 files changed, 171 insertions(+), 58 deletions(-) diff --git a/tunnel-server/.gitignore b/tunnel-server/.gitignore index e69de29bb..7446a35db 100644 --- a/tunnel-server/.gitignore +++ b/tunnel-server/.gitignore @@ -0,0 +1,2 @@ +htmlcov +.coverage diff --git a/tunnel-server/coverage.ini b/tunnel-server/coverage.ini index 383e85bdb..770fd1e30 100644 --- a/tunnel-server/coverage.ini +++ b/tunnel-server/coverage.ini @@ -1,5 +1,8 @@ [run] dynamic_context = test_function +omit = + test/* + /etc/* [report] skip_empty = True diff --git a/tunnel-server/pytest.ini b/tunnel-server/pytest.ini index c747652c1..ab72b5951 100644 --- a/tunnel-server/pytest.ini +++ b/tunnel-server/pytest.ini @@ -1,5 +1,6 @@ [pytest] -addopts = "-s" +#addopts = "-s" +addopts = --cov --cov-report html --cov-config=coverage.ini -n 12 pythonpath = ./src python_files = tests.py test_*.py *_tests.py log_format = "%(asctime)s %(levelname)s %(message)s" diff --git a/tunnel-server/src/uds_tunnel/tunnel.py b/tunnel-server/src/uds_tunnel/tunnel.py index 8fc6b5041..04c20cac6 100644 --- a/tunnel-server/src/uds_tunnel/tunnel.py +++ b/tunnel-server/src/uds_tunnel/tunnel.py @@ -114,7 +114,7 @@ class TunnelProtocol(asyncio.Protocol): async def open_other_side() -> None: try: - result = await TunnelProtocol.getTicketFromUDS( + result = await TunnelProtocol.get_ticket_from_uds( self.owner.cfg, ticket, self.source ) except Exception as e: @@ -277,10 +277,10 @@ class TunnelProtocol(asyncio.Protocol): logger.debug('Data received: %s', len(data)) self.runner(data) # send data to current runner (command or proxy) - def notifyEnd(self): + def notify_end(self): if self.notify_ticket: asyncio.get_event_loop().create_task( - TunnelProtocol.notifyEndToUds( + TunnelProtocol.notify_end_to_uds( self.owner.cfg, self.notify_ticket, self.stats_manager ) ) @@ -295,7 +295,7 @@ class TunnelProtocol(asyncio.Protocol): self.other_side.transport.close() else: self.stats_manager.close() - self.notifyEnd() + self.notify_end() # helpers @staticmethod @@ -324,12 +324,12 @@ class TunnelProtocol(asyncio.Protocol): int(self.stats_manager.end - self.stats_manager.start), ) # Notify end to uds - self.notifyEnd() + self.notify_end() else: logger.info('TERMINATED %s', self.pretty_source()) @staticmethod - async def _readFromUDS( + async def _read_from_uds( cfg: config.ConfigurationType, ticket: bytes, msg: str, @@ -358,7 +358,7 @@ class TunnelProtocol(asyncio.Protocol): raise Exception(f'TICKET COMMS ERROR: {ticket.decode()} {msg} {e!s}') @staticmethod - async def getTicketFromUDS( + async def get_ticket_from_uds( cfg: config.ConfigurationType, ticket: bytes, address: typing.Tuple[str, int] ) -> typing.MutableMapping[str, typing.Any]: # Sanity checks @@ -374,13 +374,13 @@ class TunnelProtocol(asyncio.Protocol): continue # Correctus raise ValueError(f'TICKET INVALID (char {i} at pos {n})') - return await TunnelProtocol._readFromUDS(cfg, ticket, address[0]) + return await TunnelProtocol._read_from_uds(cfg, ticket, address[0]) @staticmethod - async def notifyEndToUds( + async def notify_end_to_uds( cfg: config.ConfigurationType, ticket: bytes, counter: stats.Stats ) -> None: - await TunnelProtocol._readFromUDS( + await TunnelProtocol._read_from_uds( cfg, ticket, 'stop', diff --git a/tunnel-server/test/test_app_concurrency.py b/tunnel-server/test/test_app_concurrency.py index c2144d46c..574d8ed7d 100644 --- a/tunnel-server/test/test_app_concurrency.py +++ b/tunnel-server/test/test_app_concurrency.py @@ -30,7 +30,8 @@ Author: Adolfo Gómez, dkmaster at dkmon dot com ''' import typing import asyncio -import io +import random +import string import logging from unittest import IsolatedAsyncioTestCase, mock @@ -38,8 +39,12 @@ from uds_tunnel import consts from .utils import tuntools, tools +if typing.TYPE_CHECKING: + from uds_tunnel import config + logger = logging.getLogger(__name__) + class TestUDSTunnelApp(IsolatedAsyncioTestCase): async def test_run_app_help(self) -> None: # Executes the app with --help @@ -49,10 +54,49 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase): self.assertEqual(process.returncode, 0, f'{stdout!r} {stderr!r}') self.assertEqual(stderr, b'') self.assertIn(b'usage: udstunnel', stdout) - + + async def client_task( + self, cfg: 'config.ConfigurationType', host: str, port: int + ) -> None: + received: bytes = b'' + callback_invoked: asyncio.Event = asyncio.Event() + + def callback(data: bytes) -> None: + nonlocal received + received += data + # if data contains EOS marcker ('STREAM_END'), we are done + if b'STREAM_END' in data: + callback_invoked.set() + + async with tools.AsyncTCPServer( + host=host, port=5445, callback=callback + ) as server: + # Create a random ticket with valid format + ticket = ''.join( + random.choice(string.ascii_letters + string.digits) + for _ in range(consts.TICKET_LENGTH) + ).encode() + # Open and send handshake + async with tuntools.open_tunnel_client(cfg, use_tunnel_handshake=True) as ( + creader, + cwriter, + ): + # Now open command with ticket + cwriter.write(consts.COMMAND_OPEN) + # fake ticket, consts.TICKET_LENGTH bytes long, letters and numbers. Use a random ticket, + cwriter.write(ticket) + + await cwriter.drain() + # Read response, should be ok + data = await creader.read(1024) + self.assertEqual( + data, + consts.RESPONSE_OK, + f'Server host: {host}:{port} - Ticket: {ticket!r} - Response: {data!r}', + ) async def test_run_app_serve(self) -> None: for host in ('127.0.0.1', '::1'): async with tuntools.tunnel_app_runner(host, 7777) as process: - print(process) - \ No newline at end of file + # Create a "bunch" of servers and clients + pass diff --git a/tunnel-server/test/test_tunnel.py b/tunnel-server/test/test_tunnel.py index a59fa8351..cc5eb0ce3 100644 --- a/tunnel-server/test/test_tunnel.py +++ b/tunnel-server/test/test_tunnel.py @@ -66,7 +66,7 @@ class TestTunnel(IsolatedAsyncioTestCase): # Set timeout to 1 seconds bad_cmd = bytes(random.randint(0, 255) for _ in range(i)) # Some garbage logger.info(f'Testing invalid command with {bad_cmd!r}') - async with tuntools.create_test_tunnel(callback=lambda x: None) as cfg: + async with tuntools.create_test_tunnel(callback=lambda x: None, port=7770, remote_port=54555) as cfg: logger_mock = mock.MagicMock() with mock.patch('uds_tunnel.tunnel.logger', logger_mock): # Open connection to tunnel diff --git a/tunnel-server/test/test_tunnel_helpers.py b/tunnel-server/test/test_tunnel_helpers.py index aac30c9e5..d5bf0ca65 100644 --- a/tunnel-server/test/test_tunnel_helpers.py +++ b/tunnel-server/test/test_tunnel_helpers.py @@ -65,7 +65,7 @@ class TestTunnelHelpers(IsolatedAsyncioTestCase): # Test some invalid tickets # Valid ticket are consts.TICKET_LENGTH bytes long, and must be A-Z, a-z, 0-9 with mock.patch( - 'uds_tunnel.tunnel.TunnelProtocol._readFromUDS', + 'uds_tunnel.tunnel.TunnelProtocol._read_from_uds', new_callable=tools.AsyncMock, ) as m: m.side_effect = uds_response @@ -77,7 +77,7 @@ class TestTunnelHelpers(IsolatedAsyncioTestCase): ) with self.assertRaises(ValueError): - await tunnel.TunnelProtocol.getTicketFromUDS( + await tunnel.TunnelProtocol.get_ticket_from_uds( cfg, ticket.encode(), conf.CALLER_HOST ) @@ -85,7 +85,7 @@ class TestTunnelHelpers(IsolatedAsyncioTestCase): for i in range(0, 100): # Now some requests with valid tickets # Ensure no exception is raised - ret_value = await tunnel.TunnelProtocol.getTicketFromUDS( + ret_value = await tunnel.TunnelProtocol.get_ticket_from_uds( cfg, ticket.encode(), conf.CALLER_HOST ) # Ensure data returned is correct {host, port, notify} from mock @@ -108,7 +108,7 @@ class TestTunnelHelpers(IsolatedAsyncioTestCase): async def test_notify_end_to_uds_broker(self) -> None: _, cfg = fixtures.get_config() with mock.patch( - 'uds_tunnel.tunnel.TunnelProtocol._readFromUDS', + 'uds_tunnel.tunnel.TunnelProtocol._read_from_uds', new_callable=tools.AsyncMock, ) as m: m.side_effect = uds_response @@ -118,7 +118,7 @@ class TestTunnelHelpers(IsolatedAsyncioTestCase): ticket = conf.NOTIFY_TICKET.encode() for i in range(0, 100): - await tunnel.TunnelProtocol.notifyEndToUds(cfg, ticket, counter) + await tunnel.TunnelProtocol.notify_end_to_uds(cfg, ticket, counter) self.assertEqual(m.call_args[0][0], cfg) self.assertEqual( @@ -159,9 +159,9 @@ class TestTunnelHelpers(IsolatedAsyncioTestCase): await tools.get(fake_uds_server), '{"result":"ok"}', ) - # Now, tests _readFromUDS + # Now, tests _read_from_uds for i in range(100): - ret = await tunnel.TunnelProtocol._readFromUDS( + ret = await tunnel.TunnelProtocol._read_from_uds( cfg, conf.NOTIFY_TICKET.encode(), 'test', {'param': 'value'} ) self.assertEqual(ret, {'result': 'ok'}) diff --git a/tunnel-server/test/test_udstunnel.py b/tunnel-server/test/test_udstunnel.py index 545b3a40a..8cb14add5 100644 --- a/tunnel-server/test/test_udstunnel.py +++ b/tunnel-server/test/test_udstunnel.py @@ -55,9 +55,9 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase): # Remote is not really important in this tests, will fail before using it async with tuntools.create_tunnel_proc( host, - 7777, + 7890, # A port not used by any other test '127.0.0.1', - 12345, + 13579, # A port not used by any other test ) as cfg: for i in range(0, 8192, 128): # Set timeout to 1 seconds @@ -85,9 +85,9 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase): # Remote is not really important in this tests, will return ok before using it (this is a TEST command, not OPEN) async with tuntools.create_tunnel_proc( host, - 7777, + 7891, '127.0.0.1', - 12345, + 13581, ) as cfg: for i in range(10): # Several times # On full, we need the handshake to be done, before connecting @@ -117,10 +117,7 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase): 0, consts.TICKET_LENGTH - 1, 4 ): # All will fail. Any longer will be processed, and mock will return correct don't matter the ticket # Ticket must contain only letters and numbers - ticket = ''.join( - random.choice(string.ascii_letters + string.digits) - for _ in range(i) - ).encode() + ticket = tuntools.get_correct_ticket(i) # On full, we need the handshake to be done, before connecting # Our "test" server will simple "eat" the handshake, but we need to do it async with tuntools.open_tunnel_client( @@ -161,10 +158,7 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase): ) as cfg: for i in range(1): # Create a random ticket with valid format - ticket = ''.join( - random.choice(string.ascii_letters + string.digits) - for _ in range(consts.TICKET_LENGTH) - ).encode() + ticket = tuntools.get_correct_ticket() # On full, we need the handshake to be done, before connecting # Our "test" server will simple "eat" the handshake, but we need to do it async with tuntools.open_tunnel_client( @@ -175,21 +169,53 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase): cwriter.write(ticket) await cwriter.drain() - # Read response + # Read response, should be ok data = await creader.read(1024) - self.assertEqual(data, consts.RESPONSE_OK, f'Tunnel host: {tunnel_host}, server host: {host}') + self.assertEqual( + data, + consts.RESPONSE_OK, + f'Tunnel host: {tunnel_host}, server host: {host}', + ) # Data sent will be received by server # One single write will ensure all data is on same packet - test_str = b'Some Random Data' + bytes(random.randint(0, 255) for _ in range(8192)) + b'STREAM_END' + test_str = ( + b'Some Random Data' + + bytes(random.randint(0, 255) for _ in range(8192)) + + b'STREAM_END' + ) # Clean received data received = b'' # And reset event callback_invoked.clear() - + cwriter.write(test_str) await cwriter.drain() # Wait for callback to be invoked await callback_invoked.wait() self.assertEqual(received, test_str) + + async def test_tunnel_no_remote(self) -> None: + for host in ('127.0.0.1', '::1'): + for tunnel_host in ('127.0.0.1', '::1'): + async with tuntools.create_tunnel_proc( + tunnel_host, + 7888, + host, + 17222, # Any non used port will do the trick + ) as cfg: + ticket = tuntools.get_correct_ticket() + # On full, we need the handshake to be done, before connecting + # Our "test" server will simple "eat" the handshake, but we need to do it + async with tuntools.open_tunnel_client( + cfg, use_tunnel_handshake=True + ) as (creader, cwriter): + cwriter.write(consts.COMMAND_OPEN) + # fake ticket, consts.TICKET_LENGTH bytes long, letters and numbers. Use a random ticket, + cwriter.write(ticket) + + await cwriter.drain() + # Read response + data = await creader.read(1024) + self.assertEqual(data, b'', f'Tunnel host: {tunnel_host}, server host: {host}') diff --git a/tunnel-server/test/utils/tools.py b/tunnel-server/test/utils/tools.py index 1b772bf89..786624d24 100644 --- a/tunnel-server/test/utils/tools.py +++ b/tunnel-server/test/utils/tools.py @@ -38,7 +38,6 @@ from unittest import mock from . import certs - class AsyncMock(mock.MagicMock): async def __call__(self, *args, **kwargs): return super().__call__(*args, **kwargs) @@ -111,7 +110,8 @@ class AsyncTCPServer: port: int _server: typing.Optional[asyncio.AbstractServer] _response: typing.Optional[bytes] - _callback: typing.Optional[typing.Callable[[bytes], None]] + _callback: typing.Optional[typing.Callable[[bytes], typing.Optional[bytes]]] + _processor: typing.Optional[typing.Callable[[asyncio.StreamReader, asyncio.StreamWriter], None]] def __init__( self, @@ -119,26 +119,30 @@ class AsyncTCPServer: *, response: typing.Optional[bytes] = None, host: str = '127.0.0.1', # ip - callback: typing.Optional[typing.Callable[[bytes], None]] = None, + callback: typing.Optional[typing.Callable[[bytes], typing.Optional[bytes]]] = None, + processor: typing.Optional[typing.Callable[[asyncio.StreamReader, asyncio.StreamWriter], None]] = None, ) -> None: self.host = host self.port = port self._server = None self._response = response self._callback = callback + self._processor = processor self.data = b'' async def _handle(self, reader, writer) -> None: + if self._processor is not None: + self._processor(reader, writer) + return while True: data = await reader.read(2048) if not data: break - if self._callback: - self._callback(data) + resp = self._callback(data) if self._callback else self._response - if self._response is not None: + if resp is not None: data = self._response writer.write(data) await writer.drain() diff --git a/tunnel-server/test/utils/tuntools.py b/tunnel-server/test/utils/tuntools.py index 5e206f84e..8ab01eeb3 100644 --- a/tunnel-server/test/utils/tuntools.py +++ b/tunnel-server/test/utils/tuntools.py @@ -30,22 +30,23 @@ Author: Adolfo Gómez, dkmaster at dkmon dot com ''' import asyncio import contextlib -import os +import json import logging +import multiprocessing +import os +import random import socket import ssl -import os -import typing +import string import tempfile +import typing from unittest import mock -import multiprocessing import udstunnel -from uds_tunnel import consts, tunnel, stats, config +from uds_tunnel import config, consts, stats, tunnel from . import certs, conf, fixtures, tools - logger = logging.getLogger(__name__) if typing.TYPE_CHECKING: @@ -100,7 +101,7 @@ async def create_tunnel_proc( remote_host: str, remote_port: int, *, - response: typing.Optional[typing.Mapping[str, typing.Any]] = None + response: typing.Optional[typing.Mapping[str, typing.Any]] = None, ) -> typing.AsyncGenerator['config.ConfigurationType', None]: with create_config_file(listen_host, listen_port) as cfgfile: args = mock.MagicMock() @@ -116,7 +117,7 @@ async def create_tunnel_proc( response = conf.UDS_GET_TICKET_RESPONSE(remote_host, remote_port) with mock.patch( - 'uds_tunnel.tunnel.TunnelProtocol._readFromUDS', + 'uds_tunnel.tunnel.TunnelProtocol._read_from_uds', new_callable=tools.AsyncMock, ) as m: m.return_value = response @@ -214,21 +215,26 @@ async def create_tunnel_server( @contextlib.asynccontextmanager async def create_test_tunnel( - *, callback: typing.Callable[[bytes], None] + *, + callback: typing.Callable[[bytes], None], + port: typing.Optional[int] = None, + remote_port: typing.Optional[int] = None, ) -> typing.AsyncGenerator['config.ConfigurationType', None]: # Generate a listening server for testing tunnel # Prepare the end of the tunnel - async with tools.AsyncTCPServer(port=54876, callback=callback) as server: + async with tools.AsyncTCPServer( + port=remote_port or 54876, callback=callback + ) as server: # Create a tunnel to localhost 13579 # SSl cert for tunnel server with certs.ssl_context(server.host) as (ssl_ctx, _): _, cfg = fixtures.get_config( address=server.host, - port=7777, + port=port or 7777, ipv6=':' in server.host, ) with mock.patch( - 'uds_tunnel.tunnel.TunnelProtocol._readFromUDS', + 'uds_tunnel.tunnel.TunnelProtocol._read_from_uds', new_callable=tools.AsyncMock, ) as m: m.return_value = conf.UDS_GET_TICKET_RESPONSE(server.host, server.port) @@ -241,6 +247,26 @@ async def create_test_tunnel( await tunnel_server.wait_closed() +@contextlib.asynccontextmanager +async def create_fake_broker_server( + response: typing.Mapping[str, typing.Any], port: int = 44443 +) -> typing.AsyncGenerator[None, None]: + # crate a fake broker server + # Ignores request, and sends response + + resp: bytes = b'HTTP/1.1 200 OK\r\n\r\n' + json.dumps(response).encode() + + def callback(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): + writer.write(resp) + writer.close() + + async with tools.AsyncTCPServer(port=port, response=resp) as server: + try: + yield + finally: + pass # nothing to do + + @contextlib.asynccontextmanager async def open_tunnel_client( cfg: 'config.ConfigurationType', @@ -284,7 +310,7 @@ async def tunnel_app_runner( host: typing.Optional[str] = None, port: typing.Optional[int] = None, *, - args: typing.Optional[typing.List[str]] = None + args: typing.Optional[typing.List[str]] = None, ) -> typing.AsyncGenerator['Process', None]: # Ensure we are on src directory if os.path.basename(os.getcwd()) != 'src': @@ -312,3 +338,10 @@ async def tunnel_app_runner( if process.returncode is None: process.terminate() await process.wait() + + +def get_correct_ticket(length: int = consts.TICKET_LENGTH) -> bytes: + return ''.join( + random.choice(string.ascii_letters + string.digits) + for _ in range(length) + ).encode()