From 36fca66c9a7425373564ba7fce629feb39952087 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adolfo=20G=C3=B3mez=20Garc=C3=ADa?= Date: Wed, 14 Dec 2022 16:09:33 +0100 Subject: [PATCH] Adding tests and improving tunnel server --- tunnel-server/src/uds_tunnel/config.py | 6 +- tunnel-server/src/uds_tunnel/proxy.py | 18 ++- tunnel-server/src/uds_tunnel/tunnel.py | 20 ++-- tunnel-server/test/fixtures.py | 66 +++++++++- tunnel-server/test/test_config_file.py | 39 ++---- tunnel-server/test/test_tunnel.py | 160 +++++++++++++++++++++++++ tunnel-server/test/utils/__init__.py | 0 tunnel-server/test/utils/certs.py | 52 ++++++++ tunnel-server/test/utils/tools.py | 109 +++++++++++++++++ 9 files changed, 422 insertions(+), 48 deletions(-) create mode 100644 tunnel-server/test/test_tunnel.py create mode 100644 tunnel-server/test/utils/__init__.py create mode 100644 tunnel-server/test/utils/certs.py create mode 100644 tunnel-server/test/utils/tools.py diff --git a/tunnel-server/src/uds_tunnel/config.py b/tunnel-server/src/uds_tunnel/config.py index dc8b32ce6..06004bbf5 100644 --- a/tunnel-server/src/uds_tunnel/config.py +++ b/tunnel-server/src/uds_tunnel/config.py @@ -60,6 +60,8 @@ class ConfigurationType(typing.NamedTuple): uds_server: str uds_token: str + uds_timeout: int + uds_verify_ssl: bool secret: str allow: typing.Set[str] @@ -122,13 +124,15 @@ def read( ssl_dhparam=uds.get('ssl_dhparam'), uds_server=uds_server, uds_token=uds.get('uds_token', 'unauthorized'), + uds_timeout=int(uds.get('uds_timeout', '10')), + uds_verify_ssl=uds.get('uds_verify_ssl', 'true').lower() == 'true', secret=secret, allow=set(uds.get('allow', '127.0.0.1').split(',')), use_uvloop=uds.get('use_uvloop', 'true').lower() == 'true', ) except ValueError as e: raise Exception( - f'Mandatory configuration file in incorrect format: {e.args[0]}. Please, revise {CONFIGFILE}' + f'Mandatory configuration file in incorrect format: {e.args[0]}. Please, revise {CONFIGFILE}' ) except KeyError as e: raise Exception( diff --git a/tunnel-server/src/uds_tunnel/proxy.py b/tunnel-server/src/uds_tunnel/proxy.py index 75b29546e..efeffc24a 100644 --- a/tunnel-server/src/uds_tunnel/proxy.py +++ b/tunnel-server/src/uds_tunnel/proxy.py @@ -55,6 +55,8 @@ class Proxy: async def __call__(self, source: socket.socket, context: 'ssl.SSLContext') -> None: try: await self.proxy(source, context) + except asyncio.CancelledError: + pass # Return on cancel except Exception as e: # get source ip address try: @@ -69,11 +71,15 @@ class Proxy: # the protocol controller do the rest # Upgrade connection to SSL, and use asyncio to handle the rest - transport: 'asyncio.transports.Transport' - protocol: tunnel.TunnelProtocol - (transport, protocol) = await loop.connect_accepted_socket( # type: ignore - lambda: tunnel.TunnelProtocol(self), source, ssl=context - ) + try: + protocol: tunnel.TunnelProtocol + # (connect accepted loop not present on AbastractEventLoop definition < 3.10) + (_, protocol) = await loop.connect_accepted_socket( # type: ignore + lambda: tunnel.TunnelProtocol(self), source, ssl=context + ) + + await protocol.finished + except asyncio.CancelledError: + pass # Return on cancel - await protocol.finished return diff --git a/tunnel-server/src/uds_tunnel/tunnel.py b/tunnel-server/src/uds_tunnel/tunnel.py index 24ed7c593..23554556b 100644 --- a/tunnel-server/src/uds_tunnel/tunnel.py +++ b/tunnel-server/src/uds_tunnel/tunnel.py @@ -52,7 +52,7 @@ class TunnelProtocol(asyncio.Protocol): transport: 'asyncio.transports.Transport' other_side: 'TunnelProtocol' # Current state - runner: typing.Any # In fact, typing.Callable[[bytes], None], but mypy complains on its check + runner: typing.Any # In fact, typing.Callable[[bytes], None], but mypy complains on checking variables that are callables on classes # Command buffer cmd: bytes # Ticket @@ -91,7 +91,7 @@ class TunnelProtocol(asyncio.Protocol): self.source = ('', 0) self.destination = ('', 0) - def process_open(self): + def process_open(self) -> None: # Open Command has the ticket behind it if len(self.cmd) < consts.TICKET_LENGTH + consts.COMMAND_LENGTH: @@ -275,7 +275,7 @@ class TunnelProtocol(asyncio.Protocol): logger.info('TERMINATED %s', self.pretty_source()) @staticmethod - async def _getUdsUrl( + async def _readFromUDS( cfg: config.ConfigurationType, ticket: bytes, msg: str, @@ -289,10 +289,14 @@ class TunnelProtocol(asyncio.Protocol): url += '?' + '&'.join( [f'{key}={value}' for key, value in queryParams.items()] ) + # Set options + options: typing.Dict[str, typing.Any] = {'timeout': cfg.uds_timeout} + if cfg.uds_verify_ssl is False: + options['ssl'] = False # Requests url with aiohttp async with aiohttp.ClientSession() as session: - async with session.get(url) as r: + async with session.get(url, **options) as r: if not r.ok: raise Exception(await r.text()) return await r.json() @@ -305,7 +309,7 @@ class TunnelProtocol(asyncio.Protocol): ) -> typing.MutableMapping[str, typing.Any]: # Sanity checks if len(ticket) != consts.TICKET_LENGTH: - raise Exception(f'TICKET INVALID (len={len(ticket)})') + raise ValueError(f'TICKET INVALID (len={len(ticket)})') for n, i in enumerate(ticket.decode(errors='ignore')): if ( @@ -314,15 +318,15 @@ class TunnelProtocol(asyncio.Protocol): or (i >= 'A' and i <= 'Z') ): continue # Correctus - raise Exception(f'TICKET INVALID (char {i} at pos {n})') + raise ValueError(f'TICKET INVALID (char {i} at pos {n})') - return await TunnelProtocol._getUdsUrl(cfg, ticket, address[0]) + return await TunnelProtocol._readFromUDS(cfg, ticket, address[0]) @staticmethod async def notifyEndToUds( cfg: config.ConfigurationType, ticket: bytes, counter: stats.Stats ) -> None: - await TunnelProtocol._getUdsUrl( + await TunnelProtocol._readFromUDS( cfg, ticket, 'stop', diff --git a/tunnel-server/test/fixtures.py b/tunnel-server/test/fixtures.py index b2e64e7a4..6e3d51b9d 100644 --- a/tunnel-server/test/fixtures.py +++ b/tunnel-server/test/fixtures.py @@ -1,3 +1,39 @@ +# -*- coding: utf-8 -*- +# +# Copyright (c) 2022 Virtual Cable S.L.U. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of Virtual Cable S.L. nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +''' +Author: Adolfo Gómez, dkmaster at dkmon dot com +''' +import typing +import io +import string +import random + +from uds_tunnel import config TEST_CONFIG='''# Sample UDS tunnel configuration @@ -41,6 +77,8 @@ ssl_dhparam = {ssl_dhparam} # https://www.example.com:14333/uds/rest/tunnel/ticket uds_server = {uds_server} uds_token = {uds_token} +uds_timeout = {uds_timeout} +uds_verify_ssl = {uds_verify_ssl} # Secret to get access to admin commands (Currently only stats commands). No default for this. # Admin commands and only allowed from "allow" ips @@ -51,4 +89,30 @@ secret = {secret} # Only use IPs, no networks allowed # defaults to localhost (change if listen address is different from 0.0.0.0) allow = {allow} -''' \ No newline at end of file +''' + +def get_config(**overrides) -> typing.Tuple[typing.Mapping[str, typing.Any], config.ConfigurationType]: + values: typing.Dict[str, typing.Any] = { + 'pidfile': f'/tmp/uds_tunnel_{random.randint(0, 100)}.pid', # Random pid file + 'user': f'user{random.randint(0, 100)}', # Random user + 'loglevel': random.choice(['DEBUG', 'INFO', 'WARNING', 'ERROR']), # Random log level + 'logfile': f'/tmp/uds_tunnel_{random.randint(0, 100)}.log', # Random log file + 'logsize': random.randint(0, 100), # Random log size + 'lognumber': random.randint(0, 100), # Random log number + 'address': f'{random.randint(0, 255)}.{random.randint(0, 255)}.{random.randint(0, 255)}.{random.randint(0, 255)}', # Random address + 'workers': random.randint(1, 100), # Random workers, 0 will return as many as cpu cores + 'ssl_certificate': f'/tmp/uds_tunnel_{random.randint(0, 100)}.crt', # Random ssl certificate + 'ssl_certificate_key': f'/tmp/uds_tunnel_{random.randint(0, 100)}.key', # Random ssl certificate key + 'ssl_ciphers': f'ciphers{random.randint(0, 100)}', # Random ssl ciphers + 'ssl_dhparam': f'/tmp/uds_tunnel_{random.randint(0, 100)}.dh', # Random ssl dhparam + 'uds_server': f'https://uds_server{random.randint(0, 100)}/some_path', # Random uds server + 'uds_token': f'uds_token{"".join(random.choices(string.ascii_uppercase + string.digits, k=32))}', # Random uds token + 'uds_timeout': random.randint(0, 100), # Random uds timeout + 'uds_verify_ssl': random.choice([True, False]), # Random verify uds ssl + 'secret': f'secret{random.randint(0, 100)}', # Random secret + 'allow': f'{random.randint(0, 255)}.0.0.0', # Random allow + } + values.update(overrides) + config_file = io.StringIO(TEST_CONFIG.format(**values)) + # Read it + return values, config.read(config_file) diff --git a/tunnel-server/test/test_config_file.py b/tunnel-server/test/test_config_file.py index 8168a32d6..349b90b2a 100644 --- a/tunnel-server/test/test_config_file.py +++ b/tunnel-server/test/test_config_file.py @@ -28,54 +28,30 @@ ''' Author: Adolfo Gómez, dkmaster at dkmon dot com ''' -import typing import hashlib -import string -import io -import random from unittest import TestCase -from uds_tunnel import config - from . import fixtures + class TestConfigFile(TestCase): def test_config_file(self) -> None: # Test in-memory configuration files ramdomly created for _ in range(100): - values: typing.Mapping[str, typing.Any] = { - 'pidfile': f'/tmp/uds_tunnel_{random.randint(0, 100)}.pid', # Random pid file - 'user': f'user{random.randint(0, 100)}', # Random user - 'loglevel': random.choice(['DEBUG', 'INFO', 'WARNING', 'ERROR']), # Random log level - 'logfile': f'/tmp/uds_tunnel_{random.randint(0, 100)}.log', # Random log file - 'logsize': random.randint(0, 100), # Random log size - 'lognumber': random.randint(0, 100), # Random log number - 'address': f'{random.randint(0, 255)}.{random.randint(0, 255)}.{random.randint(0, 255)}.{random.randint(0, 255)}', # Random address - 'workers': random.randint(1, 100), # Random workers, 0 will return as many as cpu cores - 'ssl_certificate': f'/tmp/uds_tunnel_{random.randint(0, 100)}.crt', # Random ssl certificate - 'ssl_certificate_key': f'/tmp/uds_tunnel_{random.randint(0, 100)}.key', # Random ssl certificate key - 'ssl_ciphers': f'ciphers{random.randint(0, 100)}', # Random ssl ciphers - 'ssl_dhparam': f'/tmp/uds_tunnel_{random.randint(0, 100)}.dh', # Random ssl dhparam - 'uds_server': f'https://uds_server{random.randint(0, 100)}/some_path', # Random uds server - 'uds_token': f'uds_token{random.choices(string.ascii_uppercase + string.digits, k=32)}', # Random uds token - 'secret': f'secret{random.randint(0, 100)}', # Random secret - 'allow': f'{random.randint(0, 255)}.0.0.0', # Random allow + values, cfg = fixtures.get_config() - } h = hashlib.sha256() h.update(values.get('secret', '').encode()) secret = h.hexdigest() - # Generate an in-memory configuration file from fixtures.TEST_CONFIG - config_file = io.StringIO(fixtures.TEST_CONFIG.format(**values)) - # Read it - cfg = config.read(config_file) # Ensure data is correct self.assertEqual(cfg.pidfile, values['pidfile']) self.assertEqual(cfg.user, values['user']) self.assertEqual(cfg.log_level, values['loglevel']) self.assertEqual(cfg.log_file, values['logfile']) - self.assertEqual(cfg.log_size, values['logsize'] * 1024 * 1024) # Config file is in MB + self.assertEqual( + cfg.log_size, values['logsize'] * 1024 * 1024 + ) # Config file is in MB self.assertEqual(cfg.log_number, values['lognumber']) self.assertEqual(cfg.listen_address, values['address']) self.assertEqual(cfg.workers, values['workers']) @@ -85,8 +61,7 @@ class TestConfigFile(TestCase): self.assertEqual(cfg.ssl_dhparam, values['ssl_dhparam']) self.assertEqual(cfg.uds_server, values['uds_server']) self.assertEqual(cfg.uds_token, values['uds_token']) + self.assertEqual(cfg.uds_timeout, values['uds_timeout']) self.assertEqual(cfg.secret, secret) self.assertEqual(cfg.allow, {values['allow']}) - - - \ No newline at end of file + self.assertEqual(cfg.uds_verify_ssl, values['uds_verify_ssl']) diff --git a/tunnel-server/test/test_tunnel.py b/tunnel-server/test/test_tunnel.py new file mode 100644 index 000000000..de0f10417 --- /dev/null +++ b/tunnel-server/test/test_tunnel.py @@ -0,0 +1,160 @@ +# -*- coding: utf-8 -*- +# +# Copyright (c) 2022 Virtual Cable S.L.U. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of Virtual Cable S.L. nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +''' +Author: Adolfo Gómez, dkmaster at dkmon dot com +''' +import string +import random +import aiohttp + +from unittest import IsolatedAsyncioTestCase, mock + +from uds_tunnel import proxy, tunnel, consts + +from . import fixtures +from .utils import tools + +NOTIFY_TICKET = '0123456789cdef01456789abcdebcdef0123456789abcdef' +UDS_GET_TICKET_RESPONSE = { + 'host': '127.0.0.1', + 'port': 54876, + 'notify': NOTIFY_TICKET, +} +CALLER_HOST = ('host', 12345) +REMOTE_HOST = ('127.0.0.1', 54876) + + +class TestTunnel(IsolatedAsyncioTestCase): + async def test_get_ticket_from_uds(self) -> None: + _, cfg = fixtures.get_config() + # Test some invalid tickets + # Valid ticket are consts.TICKET_LENGTH bytes long, and must be A-Z, a-z, 0-9 + with mock.patch( + 'uds_tunnel.tunnel.TunnelProtocol._readFromUDS', + new_callable=tools.AsyncMock, + ) as m: + m.return_value = UDS_GET_TICKET_RESPONSE + for i in range(0, 100): + ticket = ''.join( + random.choices( + string.ascii_letters + string.digits, k=i % consts.TICKET_LENGTH + ) + ) + + with self.assertRaises(ValueError): + await tunnel.TunnelProtocol.getTicketFromUDS( + cfg, ticket.encode(), CALLER_HOST + ) + + ticket = NOTIFY_TICKET # Samle ticket + for i in range(0, 100): + # Now some requests with valid tickets + # Ensure no exception is raised + ret_value = await tunnel.TunnelProtocol.getTicketFromUDS( + cfg, ticket.encode(), CALLER_HOST + ) + # Ensure data returned is correct {host, port, notify} from mock + self.assertEqual(ret_value, m.return_value) + # Ensure mock was called with correct parameters + print(m.call_args) + # Check calling parameters, first one is the config, second one is the ticket, third one is the caller host + # no kwargs are used + self.assertEqual(m.call_args[0][0], cfg) + self.assertEqual( + m.call_args[0][1], NOTIFY_TICKET.encode() + ) # Same ticket, but bytes + self.assertEqual(m.call_args[0][2], CALLER_HOST[0]) + + print(ret_value) + + # mock should have been called 100 times + self.assertEqual(m.call_count, 100) + + async def test_notify_end_to_uds(self) -> None: + _, cfg = fixtures.get_config() + with mock.patch( + 'uds_tunnel.tunnel.TunnelProtocol._readFromUDS', + new_callable=tools.AsyncMock, + ) as m: + m.return_value = {} + counter = mock.MagicMock() + counter.sent = 123456789 + counter.recv = 987654321 + + ticket = NOTIFY_TICKET.encode() + for i in range(0, 100): + await tunnel.TunnelProtocol.notifyEndToUds(cfg, ticket, counter) + + self.assertEqual(m.call_args[0][0], cfg) + self.assertEqual( + m.call_args[0][1], NOTIFY_TICKET.encode() + ) # Same ticket, but bytes + self.assertEqual(m.call_args[0][2], 'stop') + self.assertEqual( + m.call_args[0][3], + {'sent': str(counter.sent), 'recv': str(counter.recv)}, + ) + + # mock should have been called 100 times + self.assertEqual(m.call_count, 100) + + async def test_read_from_uds(self) -> None: + # Generate a listening http server for testing UDS + # Tesst fine responses: + for use_ssl in (True, False): + async with tools.AsyncHttpServer( + port=13579, response=b'{"result":"ok"}', use_ssl=use_ssl + ) as server: + # Get server configuration, and ensure server is running fine + fake_uds_server = ( + f'http{"s" if use_ssl else ""}://127.0.0.1:{server.port}/' + ) + _, cfg = fixtures.get_config( + uds_server=fake_uds_server, uds_verify_ssl=False + ) + self.assertEqual( + await TestTunnel.get(fake_uds_server), + '{"result":"ok"}', + ) + # Now, tests _readFromUDS + for i in range(100): + ret = await tunnel.TunnelProtocol._readFromUDS( + cfg, NOTIFY_TICKET.encode(), 'test', {'param': 'value'} + ) + self.assertEqual(ret, {'result': 'ok'}) + + # Helpers + @staticmethod + async def get(url: str) -> str: + async with aiohttp.ClientSession() as session: + options = { + 'ssl': False, + } + async with session.get(url, **options) as r: + r.raise_for_status() + return await r.text() diff --git a/tunnel-server/test/utils/__init__.py b/tunnel-server/test/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tunnel-server/test/utils/certs.py b/tunnel-server/test/utils/certs.py new file mode 100644 index 000000000..5f2cd0630 --- /dev/null +++ b/tunnel-server/test/utils/certs.py @@ -0,0 +1,52 @@ +import secrets +import random +from datetime import datetime, timedelta +import ipaddress +import typing + +from cryptography import x509 +from cryptography.x509.oid import NameOID +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa + + +def selfSignedCert(ip: str) -> typing.Tuple[str, str, str]: + key = rsa.generate_private_key( + public_exponent=65537, + key_size=2048, + backend=default_backend(), + ) + # Create a random password for private key + password = secrets.token_urlsafe(32) + + name = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, ip)]) + san = x509.SubjectAlternativeName([x509.IPAddress(ipaddress.ip_address(ip))]) + + basic_contraints = x509.BasicConstraints(ca=True, path_length=0) + now = datetime.utcnow() + cert = ( + x509.CertificateBuilder() + .subject_name(name) + .issuer_name(name) # self signed, its Issuer DN must match its Subject DN. + .public_key(key.public_key()) + .serial_number(random.SystemRandom().randint(0, 1 << 64)) + .not_valid_before(now) + .not_valid_after(now + timedelta(days=10 * 365)) + .add_extension(basic_contraints, False) + .add_extension(san, False) + .sign(key, hashes.SHA256(), default_backend()) + ) + + return ( + key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.BestAvailableEncryption( + password.encode() + ), + ).decode(), + cert.public_bytes(encoding=serialization.Encoding.PEM).decode(), + password, + ) diff --git a/tunnel-server/test/utils/tools.py b/tunnel-server/test/utils/tools.py new file mode 100644 index 000000000..6ce2952f9 --- /dev/null +++ b/tunnel-server/test/utils/tools.py @@ -0,0 +1,109 @@ +# -*- coding: utf-8 -*- +# +# Copyright (c) 2022 Virtual Cable S.L.U. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of Virtual Cable S.L. nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +''' +Author: Adolfo Gómez, dkmaster at dkmon dot com +''' +import asyncio +import os +import ssl +import typing +import tempfile +from unittest import mock + +from . import certs + +class AsyncMock(mock.MagicMock): + async def __call__(self, *args, **kwargs): + return super().__call__(*args, **kwargs) + + +# simple async http server, will return 200 OK with the request path as body +class AsyncHttpServer: + port: int + _server: typing.Optional[asyncio.AbstractServer] + _response: typing.Optional[bytes] + _ssl_ctx: typing.Optional[ssl.SSLContext] + + def __init__( + self, port: int, *, response: typing.Optional[bytes] = None, use_ssl: bool = False, + host: str = '127.0.0.1' # ip + ): + self.port = port + self._server = None + self._response = response + if use_ssl: + # First, create server cert and key on temp dir + tmpdir = tempfile.gettempdir() + cert, key, password = certs.selfSignedCert('127.0.0.1') + with open(f'{tmpdir}/tmp_cert.pem', 'w') as f: + f.write(key) + f.write(cert) + # Create SSL context + self._ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + self._ssl_ctx.options |= ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 + self._ssl_ctx.load_cert_chain(certfile=f'{tmpdir}/tmp_cert.pem', password=password) + self._ssl_ctx.check_hostname = False + self._ssl_ctx.set_ciphers( + 'ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384' + ) + else: + self._ssl_ctx = None + + # on end, remove certs + def __del__(self): + tmpdir = tempfile.gettempdir() + # os.remove(f'{tmpdir}/tmp_cert.pem') + + async def _handle(self, reader, writer): + data = await reader.read(2048) + path: bytes = data.split()[1] + if self._response is not None: + path = self._response + writer.write( + b'HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: %d\r\n\r\n%s' + % (len(path), path) + ) + await writer.drain() + + async def __aenter__(self): + if self._ssl_ctx is not None: + self._server = await asyncio.start_server( + self._handle, '127.0.0.1', self.port, ssl=self._ssl_ctx + ) + else: + self._server = await asyncio.start_server( + self._handle, '127.0.0.1', self.port + ) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + if self._server is not None: + self._server.close() + await self._server.wait_closed() + self._server = None +