1
0
mirror of https://github.com/dkmstr/openuds.git synced 2024-12-23 17:34:17 +03:00

Adding tests and improving tunnel server

This commit is contained in:
Adolfo Gómez García 2022-12-14 16:09:33 +01:00
parent f12ce12155
commit 36fca66c9a
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
9 changed files with 422 additions and 48 deletions

View File

@ -60,6 +60,8 @@ class ConfigurationType(typing.NamedTuple):
uds_server: str
uds_token: str
uds_timeout: int
uds_verify_ssl: bool
secret: str
allow: typing.Set[str]
@ -122,13 +124,15 @@ def read(
ssl_dhparam=uds.get('ssl_dhparam'),
uds_server=uds_server,
uds_token=uds.get('uds_token', 'unauthorized'),
uds_timeout=int(uds.get('uds_timeout', '10')),
uds_verify_ssl=uds.get('uds_verify_ssl', 'true').lower() == 'true',
secret=secret,
allow=set(uds.get('allow', '127.0.0.1').split(',')),
use_uvloop=uds.get('use_uvloop', 'true').lower() == 'true',
)
except ValueError as e:
raise Exception(
f'Mandatory configuration file in incorrect format: {e.args[0]}. Please, revise {CONFIGFILE}'
f'Mandatory configuration file in incorrect format: {e.args[0]}. Please, revise {CONFIGFILE}'
)
except KeyError as e:
raise Exception(

View File

@ -55,6 +55,8 @@ class Proxy:
async def __call__(self, source: socket.socket, context: 'ssl.SSLContext') -> None:
try:
await self.proxy(source, context)
except asyncio.CancelledError:
pass # Return on cancel
except Exception as e:
# get source ip address
try:
@ -69,11 +71,15 @@ class Proxy:
# the protocol controller do the rest
# Upgrade connection to SSL, and use asyncio to handle the rest
transport: 'asyncio.transports.Transport'
protocol: tunnel.TunnelProtocol
(transport, protocol) = await loop.connect_accepted_socket( # type: ignore
lambda: tunnel.TunnelProtocol(self), source, ssl=context
)
try:
protocol: tunnel.TunnelProtocol
# (connect accepted loop not present on AbastractEventLoop definition < 3.10)
(_, protocol) = await loop.connect_accepted_socket( # type: ignore
lambda: tunnel.TunnelProtocol(self), source, ssl=context
)
await protocol.finished
except asyncio.CancelledError:
pass # Return on cancel
await protocol.finished
return

View File

@ -52,7 +52,7 @@ class TunnelProtocol(asyncio.Protocol):
transport: 'asyncio.transports.Transport'
other_side: 'TunnelProtocol'
# Current state
runner: typing.Any # In fact, typing.Callable[[bytes], None], but mypy complains on its check
runner: typing.Any # In fact, typing.Callable[[bytes], None], but mypy complains on checking variables that are callables on classes
# Command buffer
cmd: bytes
# Ticket
@ -91,7 +91,7 @@ class TunnelProtocol(asyncio.Protocol):
self.source = ('', 0)
self.destination = ('', 0)
def process_open(self):
def process_open(self) -> None:
# Open Command has the ticket behind it
if len(self.cmd) < consts.TICKET_LENGTH + consts.COMMAND_LENGTH:
@ -275,7 +275,7 @@ class TunnelProtocol(asyncio.Protocol):
logger.info('TERMINATED %s', self.pretty_source())
@staticmethod
async def _getUdsUrl(
async def _readFromUDS(
cfg: config.ConfigurationType,
ticket: bytes,
msg: str,
@ -289,10 +289,14 @@ class TunnelProtocol(asyncio.Protocol):
url += '?' + '&'.join(
[f'{key}={value}' for key, value in queryParams.items()]
)
# Set options
options: typing.Dict[str, typing.Any] = {'timeout': cfg.uds_timeout}
if cfg.uds_verify_ssl is False:
options['ssl'] = False
# Requests url with aiohttp
async with aiohttp.ClientSession() as session:
async with session.get(url) as r:
async with session.get(url, **options) as r:
if not r.ok:
raise Exception(await r.text())
return await r.json()
@ -305,7 +309,7 @@ class TunnelProtocol(asyncio.Protocol):
) -> typing.MutableMapping[str, typing.Any]:
# Sanity checks
if len(ticket) != consts.TICKET_LENGTH:
raise Exception(f'TICKET INVALID (len={len(ticket)})')
raise ValueError(f'TICKET INVALID (len={len(ticket)})')
for n, i in enumerate(ticket.decode(errors='ignore')):
if (
@ -314,15 +318,15 @@ class TunnelProtocol(asyncio.Protocol):
or (i >= 'A' and i <= 'Z')
):
continue # Correctus
raise Exception(f'TICKET INVALID (char {i} at pos {n})')
raise ValueError(f'TICKET INVALID (char {i} at pos {n})')
return await TunnelProtocol._getUdsUrl(cfg, ticket, address[0])
return await TunnelProtocol._readFromUDS(cfg, ticket, address[0])
@staticmethod
async def notifyEndToUds(
cfg: config.ConfigurationType, ticket: bytes, counter: stats.Stats
) -> None:
await TunnelProtocol._getUdsUrl(
await TunnelProtocol._readFromUDS(
cfg,
ticket,
'stop',

View File

@ -1,3 +1,39 @@
# -*- coding: utf-8 -*-
#
# Copyright (c) 2022 Virtual Cable S.L.U.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of Virtual Cable S.L. nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
'''
Author: Adolfo Gómez, dkmaster at dkmon dot com
'''
import typing
import io
import string
import random
from uds_tunnel import config
TEST_CONFIG='''# Sample UDS tunnel configuration
@ -41,6 +77,8 @@ ssl_dhparam = {ssl_dhparam}
# https://www.example.com:14333/uds/rest/tunnel/ticket
uds_server = {uds_server}
uds_token = {uds_token}
uds_timeout = {uds_timeout}
uds_verify_ssl = {uds_verify_ssl}
# Secret to get access to admin commands (Currently only stats commands). No default for this.
# Admin commands and only allowed from "allow" ips
@ -51,4 +89,30 @@ secret = {secret}
# Only use IPs, no networks allowed
# defaults to localhost (change if listen address is different from 0.0.0.0)
allow = {allow}
'''
'''
def get_config(**overrides) -> typing.Tuple[typing.Mapping[str, typing.Any], config.ConfigurationType]:
values: typing.Dict[str, typing.Any] = {
'pidfile': f'/tmp/uds_tunnel_{random.randint(0, 100)}.pid', # Random pid file
'user': f'user{random.randint(0, 100)}', # Random user
'loglevel': random.choice(['DEBUG', 'INFO', 'WARNING', 'ERROR']), # Random log level
'logfile': f'/tmp/uds_tunnel_{random.randint(0, 100)}.log', # Random log file
'logsize': random.randint(0, 100), # Random log size
'lognumber': random.randint(0, 100), # Random log number
'address': f'{random.randint(0, 255)}.{random.randint(0, 255)}.{random.randint(0, 255)}.{random.randint(0, 255)}', # Random address
'workers': random.randint(1, 100), # Random workers, 0 will return as many as cpu cores
'ssl_certificate': f'/tmp/uds_tunnel_{random.randint(0, 100)}.crt', # Random ssl certificate
'ssl_certificate_key': f'/tmp/uds_tunnel_{random.randint(0, 100)}.key', # Random ssl certificate key
'ssl_ciphers': f'ciphers{random.randint(0, 100)}', # Random ssl ciphers
'ssl_dhparam': f'/tmp/uds_tunnel_{random.randint(0, 100)}.dh', # Random ssl dhparam
'uds_server': f'https://uds_server{random.randint(0, 100)}/some_path', # Random uds server
'uds_token': f'uds_token{"".join(random.choices(string.ascii_uppercase + string.digits, k=32))}', # Random uds token
'uds_timeout': random.randint(0, 100), # Random uds timeout
'uds_verify_ssl': random.choice([True, False]), # Random verify uds ssl
'secret': f'secret{random.randint(0, 100)}', # Random secret
'allow': f'{random.randint(0, 255)}.0.0.0', # Random allow
}
values.update(overrides)
config_file = io.StringIO(TEST_CONFIG.format(**values))
# Read it
return values, config.read(config_file)

View File

@ -28,54 +28,30 @@
'''
Author: Adolfo Gómez, dkmaster at dkmon dot com
'''
import typing
import hashlib
import string
import io
import random
from unittest import TestCase
from uds_tunnel import config
from . import fixtures
class TestConfigFile(TestCase):
def test_config_file(self) -> None:
# Test in-memory configuration files ramdomly created
for _ in range(100):
values: typing.Mapping[str, typing.Any] = {
'pidfile': f'/tmp/uds_tunnel_{random.randint(0, 100)}.pid', # Random pid file
'user': f'user{random.randint(0, 100)}', # Random user
'loglevel': random.choice(['DEBUG', 'INFO', 'WARNING', 'ERROR']), # Random log level
'logfile': f'/tmp/uds_tunnel_{random.randint(0, 100)}.log', # Random log file
'logsize': random.randint(0, 100), # Random log size
'lognumber': random.randint(0, 100), # Random log number
'address': f'{random.randint(0, 255)}.{random.randint(0, 255)}.{random.randint(0, 255)}.{random.randint(0, 255)}', # Random address
'workers': random.randint(1, 100), # Random workers, 0 will return as many as cpu cores
'ssl_certificate': f'/tmp/uds_tunnel_{random.randint(0, 100)}.crt', # Random ssl certificate
'ssl_certificate_key': f'/tmp/uds_tunnel_{random.randint(0, 100)}.key', # Random ssl certificate key
'ssl_ciphers': f'ciphers{random.randint(0, 100)}', # Random ssl ciphers
'ssl_dhparam': f'/tmp/uds_tunnel_{random.randint(0, 100)}.dh', # Random ssl dhparam
'uds_server': f'https://uds_server{random.randint(0, 100)}/some_path', # Random uds server
'uds_token': f'uds_token{random.choices(string.ascii_uppercase + string.digits, k=32)}', # Random uds token
'secret': f'secret{random.randint(0, 100)}', # Random secret
'allow': f'{random.randint(0, 255)}.0.0.0', # Random allow
values, cfg = fixtures.get_config()
}
h = hashlib.sha256()
h.update(values.get('secret', '').encode())
secret = h.hexdigest()
# Generate an in-memory configuration file from fixtures.TEST_CONFIG
config_file = io.StringIO(fixtures.TEST_CONFIG.format(**values))
# Read it
cfg = config.read(config_file)
# Ensure data is correct
self.assertEqual(cfg.pidfile, values['pidfile'])
self.assertEqual(cfg.user, values['user'])
self.assertEqual(cfg.log_level, values['loglevel'])
self.assertEqual(cfg.log_file, values['logfile'])
self.assertEqual(cfg.log_size, values['logsize'] * 1024 * 1024) # Config file is in MB
self.assertEqual(
cfg.log_size, values['logsize'] * 1024 * 1024
) # Config file is in MB
self.assertEqual(cfg.log_number, values['lognumber'])
self.assertEqual(cfg.listen_address, values['address'])
self.assertEqual(cfg.workers, values['workers'])
@ -85,8 +61,7 @@ class TestConfigFile(TestCase):
self.assertEqual(cfg.ssl_dhparam, values['ssl_dhparam'])
self.assertEqual(cfg.uds_server, values['uds_server'])
self.assertEqual(cfg.uds_token, values['uds_token'])
self.assertEqual(cfg.uds_timeout, values['uds_timeout'])
self.assertEqual(cfg.secret, secret)
self.assertEqual(cfg.allow, {values['allow']})
self.assertEqual(cfg.uds_verify_ssl, values['uds_verify_ssl'])

View File

@ -0,0 +1,160 @@
# -*- coding: utf-8 -*-
#
# Copyright (c) 2022 Virtual Cable S.L.U.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of Virtual Cable S.L. nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
'''
Author: Adolfo Gómez, dkmaster at dkmon dot com
'''
import string
import random
import aiohttp
from unittest import IsolatedAsyncioTestCase, mock
from uds_tunnel import proxy, tunnel, consts
from . import fixtures
from .utils import tools
NOTIFY_TICKET = '0123456789cdef01456789abcdebcdef0123456789abcdef'
UDS_GET_TICKET_RESPONSE = {
'host': '127.0.0.1',
'port': 54876,
'notify': NOTIFY_TICKET,
}
CALLER_HOST = ('host', 12345)
REMOTE_HOST = ('127.0.0.1', 54876)
class TestTunnel(IsolatedAsyncioTestCase):
async def test_get_ticket_from_uds(self) -> None:
_, cfg = fixtures.get_config()
# Test some invalid tickets
# Valid ticket are consts.TICKET_LENGTH bytes long, and must be A-Z, a-z, 0-9
with mock.patch(
'uds_tunnel.tunnel.TunnelProtocol._readFromUDS',
new_callable=tools.AsyncMock,
) as m:
m.return_value = UDS_GET_TICKET_RESPONSE
for i in range(0, 100):
ticket = ''.join(
random.choices(
string.ascii_letters + string.digits, k=i % consts.TICKET_LENGTH
)
)
with self.assertRaises(ValueError):
await tunnel.TunnelProtocol.getTicketFromUDS(
cfg, ticket.encode(), CALLER_HOST
)
ticket = NOTIFY_TICKET # Samle ticket
for i in range(0, 100):
# Now some requests with valid tickets
# Ensure no exception is raised
ret_value = await tunnel.TunnelProtocol.getTicketFromUDS(
cfg, ticket.encode(), CALLER_HOST
)
# Ensure data returned is correct {host, port, notify} from mock
self.assertEqual(ret_value, m.return_value)
# Ensure mock was called with correct parameters
print(m.call_args)
# Check calling parameters, first one is the config, second one is the ticket, third one is the caller host
# no kwargs are used
self.assertEqual(m.call_args[0][0], cfg)
self.assertEqual(
m.call_args[0][1], NOTIFY_TICKET.encode()
) # Same ticket, but bytes
self.assertEqual(m.call_args[0][2], CALLER_HOST[0])
print(ret_value)
# mock should have been called 100 times
self.assertEqual(m.call_count, 100)
async def test_notify_end_to_uds(self) -> None:
_, cfg = fixtures.get_config()
with mock.patch(
'uds_tunnel.tunnel.TunnelProtocol._readFromUDS',
new_callable=tools.AsyncMock,
) as m:
m.return_value = {}
counter = mock.MagicMock()
counter.sent = 123456789
counter.recv = 987654321
ticket = NOTIFY_TICKET.encode()
for i in range(0, 100):
await tunnel.TunnelProtocol.notifyEndToUds(cfg, ticket, counter)
self.assertEqual(m.call_args[0][0], cfg)
self.assertEqual(
m.call_args[0][1], NOTIFY_TICKET.encode()
) # Same ticket, but bytes
self.assertEqual(m.call_args[0][2], 'stop')
self.assertEqual(
m.call_args[0][3],
{'sent': str(counter.sent), 'recv': str(counter.recv)},
)
# mock should have been called 100 times
self.assertEqual(m.call_count, 100)
async def test_read_from_uds(self) -> None:
# Generate a listening http server for testing UDS
# Tesst fine responses:
for use_ssl in (True, False):
async with tools.AsyncHttpServer(
port=13579, response=b'{"result":"ok"}', use_ssl=use_ssl
) as server:
# Get server configuration, and ensure server is running fine
fake_uds_server = (
f'http{"s" if use_ssl else ""}://127.0.0.1:{server.port}/'
)
_, cfg = fixtures.get_config(
uds_server=fake_uds_server, uds_verify_ssl=False
)
self.assertEqual(
await TestTunnel.get(fake_uds_server),
'{"result":"ok"}',
)
# Now, tests _readFromUDS
for i in range(100):
ret = await tunnel.TunnelProtocol._readFromUDS(
cfg, NOTIFY_TICKET.encode(), 'test', {'param': 'value'}
)
self.assertEqual(ret, {'result': 'ok'})
# Helpers
@staticmethod
async def get(url: str) -> str:
async with aiohttp.ClientSession() as session:
options = {
'ssl': False,
}
async with session.get(url, **options) as r:
r.raise_for_status()
return await r.text()

View File

View File

@ -0,0 +1,52 @@
import secrets
import random
from datetime import datetime, timedelta
import ipaddress
import typing
from cryptography import x509
from cryptography.x509.oid import NameOID
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
def selfSignedCert(ip: str) -> typing.Tuple[str, str, str]:
key = rsa.generate_private_key(
public_exponent=65537,
key_size=2048,
backend=default_backend(),
)
# Create a random password for private key
password = secrets.token_urlsafe(32)
name = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, ip)])
san = x509.SubjectAlternativeName([x509.IPAddress(ipaddress.ip_address(ip))])
basic_contraints = x509.BasicConstraints(ca=True, path_length=0)
now = datetime.utcnow()
cert = (
x509.CertificateBuilder()
.subject_name(name)
.issuer_name(name) # self signed, its Issuer DN must match its Subject DN.
.public_key(key.public_key())
.serial_number(random.SystemRandom().randint(0, 1 << 64))
.not_valid_before(now)
.not_valid_after(now + timedelta(days=10 * 365))
.add_extension(basic_contraints, False)
.add_extension(san, False)
.sign(key, hashes.SHA256(), default_backend())
)
return (
key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.BestAvailableEncryption(
password.encode()
),
).decode(),
cert.public_bytes(encoding=serialization.Encoding.PEM).decode(),
password,
)

View File

@ -0,0 +1,109 @@
# -*- coding: utf-8 -*-
#
# Copyright (c) 2022 Virtual Cable S.L.U.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of Virtual Cable S.L. nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
'''
Author: Adolfo Gómez, dkmaster at dkmon dot com
'''
import asyncio
import os
import ssl
import typing
import tempfile
from unittest import mock
from . import certs
class AsyncMock(mock.MagicMock):
async def __call__(self, *args, **kwargs):
return super().__call__(*args, **kwargs)
# simple async http server, will return 200 OK with the request path as body
class AsyncHttpServer:
port: int
_server: typing.Optional[asyncio.AbstractServer]
_response: typing.Optional[bytes]
_ssl_ctx: typing.Optional[ssl.SSLContext]
def __init__(
self, port: int, *, response: typing.Optional[bytes] = None, use_ssl: bool = False,
host: str = '127.0.0.1' # ip
):
self.port = port
self._server = None
self._response = response
if use_ssl:
# First, create server cert and key on temp dir
tmpdir = tempfile.gettempdir()
cert, key, password = certs.selfSignedCert('127.0.0.1')
with open(f'{tmpdir}/tmp_cert.pem', 'w') as f:
f.write(key)
f.write(cert)
# Create SSL context
self._ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
self._ssl_ctx.options |= ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1
self._ssl_ctx.load_cert_chain(certfile=f'{tmpdir}/tmp_cert.pem', password=password)
self._ssl_ctx.check_hostname = False
self._ssl_ctx.set_ciphers(
'ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384'
)
else:
self._ssl_ctx = None
# on end, remove certs
def __del__(self):
tmpdir = tempfile.gettempdir()
# os.remove(f'{tmpdir}/tmp_cert.pem')
async def _handle(self, reader, writer):
data = await reader.read(2048)
path: bytes = data.split()[1]
if self._response is not None:
path = self._response
writer.write(
b'HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: %d\r\n\r\n%s'
% (len(path), path)
)
await writer.drain()
async def __aenter__(self):
if self._ssl_ctx is not None:
self._server = await asyncio.start_server(
self._handle, '127.0.0.1', self.port, ssl=self._ssl_ctx
)
else:
self._server = await asyncio.start_server(
self._handle, '127.0.0.1', self.port
)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
if self._server is not None:
self._server.close()
await self._server.wait_closed()
self._server = None