mirror of
https://github.com/dkmstr/openuds.git
synced 2024-12-23 17:34:17 +03:00
Adding tests and improving tunnel server
This commit is contained in:
parent
f12ce12155
commit
36fca66c9a
@ -60,6 +60,8 @@ class ConfigurationType(typing.NamedTuple):
|
||||
|
||||
uds_server: str
|
||||
uds_token: str
|
||||
uds_timeout: int
|
||||
uds_verify_ssl: bool
|
||||
|
||||
secret: str
|
||||
allow: typing.Set[str]
|
||||
@ -122,13 +124,15 @@ def read(
|
||||
ssl_dhparam=uds.get('ssl_dhparam'),
|
||||
uds_server=uds_server,
|
||||
uds_token=uds.get('uds_token', 'unauthorized'),
|
||||
uds_timeout=int(uds.get('uds_timeout', '10')),
|
||||
uds_verify_ssl=uds.get('uds_verify_ssl', 'true').lower() == 'true',
|
||||
secret=secret,
|
||||
allow=set(uds.get('allow', '127.0.0.1').split(',')),
|
||||
use_uvloop=uds.get('use_uvloop', 'true').lower() == 'true',
|
||||
)
|
||||
except ValueError as e:
|
||||
raise Exception(
|
||||
f'Mandatory configuration file in incorrect format: {e.args[0]}. Please, revise {CONFIGFILE}'
|
||||
f'Mandatory configuration file in incorrect format: {e.args[0]}. Please, revise {CONFIGFILE}'
|
||||
)
|
||||
except KeyError as e:
|
||||
raise Exception(
|
||||
|
@ -55,6 +55,8 @@ class Proxy:
|
||||
async def __call__(self, source: socket.socket, context: 'ssl.SSLContext') -> None:
|
||||
try:
|
||||
await self.proxy(source, context)
|
||||
except asyncio.CancelledError:
|
||||
pass # Return on cancel
|
||||
except Exception as e:
|
||||
# get source ip address
|
||||
try:
|
||||
@ -69,11 +71,15 @@ class Proxy:
|
||||
# the protocol controller do the rest
|
||||
|
||||
# Upgrade connection to SSL, and use asyncio to handle the rest
|
||||
transport: 'asyncio.transports.Transport'
|
||||
protocol: tunnel.TunnelProtocol
|
||||
(transport, protocol) = await loop.connect_accepted_socket( # type: ignore
|
||||
lambda: tunnel.TunnelProtocol(self), source, ssl=context
|
||||
)
|
||||
try:
|
||||
protocol: tunnel.TunnelProtocol
|
||||
# (connect accepted loop not present on AbastractEventLoop definition < 3.10)
|
||||
(_, protocol) = await loop.connect_accepted_socket( # type: ignore
|
||||
lambda: tunnel.TunnelProtocol(self), source, ssl=context
|
||||
)
|
||||
|
||||
await protocol.finished
|
||||
except asyncio.CancelledError:
|
||||
pass # Return on cancel
|
||||
|
||||
await protocol.finished
|
||||
return
|
||||
|
@ -52,7 +52,7 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
transport: 'asyncio.transports.Transport'
|
||||
other_side: 'TunnelProtocol'
|
||||
# Current state
|
||||
runner: typing.Any # In fact, typing.Callable[[bytes], None], but mypy complains on its check
|
||||
runner: typing.Any # In fact, typing.Callable[[bytes], None], but mypy complains on checking variables that are callables on classes
|
||||
# Command buffer
|
||||
cmd: bytes
|
||||
# Ticket
|
||||
@ -91,7 +91,7 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
self.source = ('', 0)
|
||||
self.destination = ('', 0)
|
||||
|
||||
def process_open(self):
|
||||
def process_open(self) -> None:
|
||||
# Open Command has the ticket behind it
|
||||
|
||||
if len(self.cmd) < consts.TICKET_LENGTH + consts.COMMAND_LENGTH:
|
||||
@ -275,7 +275,7 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
logger.info('TERMINATED %s', self.pretty_source())
|
||||
|
||||
@staticmethod
|
||||
async def _getUdsUrl(
|
||||
async def _readFromUDS(
|
||||
cfg: config.ConfigurationType,
|
||||
ticket: bytes,
|
||||
msg: str,
|
||||
@ -289,10 +289,14 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
url += '?' + '&'.join(
|
||||
[f'{key}={value}' for key, value in queryParams.items()]
|
||||
)
|
||||
# Set options
|
||||
options: typing.Dict[str, typing.Any] = {'timeout': cfg.uds_timeout}
|
||||
if cfg.uds_verify_ssl is False:
|
||||
options['ssl'] = False
|
||||
# Requests url with aiohttp
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url) as r:
|
||||
async with session.get(url, **options) as r:
|
||||
if not r.ok:
|
||||
raise Exception(await r.text())
|
||||
return await r.json()
|
||||
@ -305,7 +309,7 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
) -> typing.MutableMapping[str, typing.Any]:
|
||||
# Sanity checks
|
||||
if len(ticket) != consts.TICKET_LENGTH:
|
||||
raise Exception(f'TICKET INVALID (len={len(ticket)})')
|
||||
raise ValueError(f'TICKET INVALID (len={len(ticket)})')
|
||||
|
||||
for n, i in enumerate(ticket.decode(errors='ignore')):
|
||||
if (
|
||||
@ -314,15 +318,15 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
or (i >= 'A' and i <= 'Z')
|
||||
):
|
||||
continue # Correctus
|
||||
raise Exception(f'TICKET INVALID (char {i} at pos {n})')
|
||||
raise ValueError(f'TICKET INVALID (char {i} at pos {n})')
|
||||
|
||||
return await TunnelProtocol._getUdsUrl(cfg, ticket, address[0])
|
||||
return await TunnelProtocol._readFromUDS(cfg, ticket, address[0])
|
||||
|
||||
@staticmethod
|
||||
async def notifyEndToUds(
|
||||
cfg: config.ConfigurationType, ticket: bytes, counter: stats.Stats
|
||||
) -> None:
|
||||
await TunnelProtocol._getUdsUrl(
|
||||
await TunnelProtocol._readFromUDS(
|
||||
cfg,
|
||||
ticket,
|
||||
'stop',
|
||||
|
@ -1,3 +1,39 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright (c) 2022 Virtual Cable S.L.U.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification,
|
||||
# are permitted provided that the following conditions are met:
|
||||
#
|
||||
# * Redistributions of source code must retain the above copyright notice,
|
||||
# this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
# * Neither the name of Virtual Cable S.L. nor the names of its contributors
|
||||
# may be used to endorse or promote products derived from this software
|
||||
# without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
'''
|
||||
Author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
'''
|
||||
import typing
|
||||
import io
|
||||
import string
|
||||
import random
|
||||
|
||||
from uds_tunnel import config
|
||||
|
||||
TEST_CONFIG='''# Sample UDS tunnel configuration
|
||||
|
||||
@ -41,6 +77,8 @@ ssl_dhparam = {ssl_dhparam}
|
||||
# https://www.example.com:14333/uds/rest/tunnel/ticket
|
||||
uds_server = {uds_server}
|
||||
uds_token = {uds_token}
|
||||
uds_timeout = {uds_timeout}
|
||||
uds_verify_ssl = {uds_verify_ssl}
|
||||
|
||||
# Secret to get access to admin commands (Currently only stats commands). No default for this.
|
||||
# Admin commands and only allowed from "allow" ips
|
||||
@ -51,4 +89,30 @@ secret = {secret}
|
||||
# Only use IPs, no networks allowed
|
||||
# defaults to localhost (change if listen address is different from 0.0.0.0)
|
||||
allow = {allow}
|
||||
'''
|
||||
'''
|
||||
|
||||
def get_config(**overrides) -> typing.Tuple[typing.Mapping[str, typing.Any], config.ConfigurationType]:
|
||||
values: typing.Dict[str, typing.Any] = {
|
||||
'pidfile': f'/tmp/uds_tunnel_{random.randint(0, 100)}.pid', # Random pid file
|
||||
'user': f'user{random.randint(0, 100)}', # Random user
|
||||
'loglevel': random.choice(['DEBUG', 'INFO', 'WARNING', 'ERROR']), # Random log level
|
||||
'logfile': f'/tmp/uds_tunnel_{random.randint(0, 100)}.log', # Random log file
|
||||
'logsize': random.randint(0, 100), # Random log size
|
||||
'lognumber': random.randint(0, 100), # Random log number
|
||||
'address': f'{random.randint(0, 255)}.{random.randint(0, 255)}.{random.randint(0, 255)}.{random.randint(0, 255)}', # Random address
|
||||
'workers': random.randint(1, 100), # Random workers, 0 will return as many as cpu cores
|
||||
'ssl_certificate': f'/tmp/uds_tunnel_{random.randint(0, 100)}.crt', # Random ssl certificate
|
||||
'ssl_certificate_key': f'/tmp/uds_tunnel_{random.randint(0, 100)}.key', # Random ssl certificate key
|
||||
'ssl_ciphers': f'ciphers{random.randint(0, 100)}', # Random ssl ciphers
|
||||
'ssl_dhparam': f'/tmp/uds_tunnel_{random.randint(0, 100)}.dh', # Random ssl dhparam
|
||||
'uds_server': f'https://uds_server{random.randint(0, 100)}/some_path', # Random uds server
|
||||
'uds_token': f'uds_token{"".join(random.choices(string.ascii_uppercase + string.digits, k=32))}', # Random uds token
|
||||
'uds_timeout': random.randint(0, 100), # Random uds timeout
|
||||
'uds_verify_ssl': random.choice([True, False]), # Random verify uds ssl
|
||||
'secret': f'secret{random.randint(0, 100)}', # Random secret
|
||||
'allow': f'{random.randint(0, 255)}.0.0.0', # Random allow
|
||||
}
|
||||
values.update(overrides)
|
||||
config_file = io.StringIO(TEST_CONFIG.format(**values))
|
||||
# Read it
|
||||
return values, config.read(config_file)
|
||||
|
@ -28,54 +28,30 @@
|
||||
'''
|
||||
Author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
'''
|
||||
import typing
|
||||
import hashlib
|
||||
import string
|
||||
import io
|
||||
import random
|
||||
|
||||
from unittest import TestCase
|
||||
|
||||
from uds_tunnel import config
|
||||
|
||||
from . import fixtures
|
||||
|
||||
|
||||
class TestConfigFile(TestCase):
|
||||
def test_config_file(self) -> None:
|
||||
# Test in-memory configuration files ramdomly created
|
||||
for _ in range(100):
|
||||
values: typing.Mapping[str, typing.Any] = {
|
||||
'pidfile': f'/tmp/uds_tunnel_{random.randint(0, 100)}.pid', # Random pid file
|
||||
'user': f'user{random.randint(0, 100)}', # Random user
|
||||
'loglevel': random.choice(['DEBUG', 'INFO', 'WARNING', 'ERROR']), # Random log level
|
||||
'logfile': f'/tmp/uds_tunnel_{random.randint(0, 100)}.log', # Random log file
|
||||
'logsize': random.randint(0, 100), # Random log size
|
||||
'lognumber': random.randint(0, 100), # Random log number
|
||||
'address': f'{random.randint(0, 255)}.{random.randint(0, 255)}.{random.randint(0, 255)}.{random.randint(0, 255)}', # Random address
|
||||
'workers': random.randint(1, 100), # Random workers, 0 will return as many as cpu cores
|
||||
'ssl_certificate': f'/tmp/uds_tunnel_{random.randint(0, 100)}.crt', # Random ssl certificate
|
||||
'ssl_certificate_key': f'/tmp/uds_tunnel_{random.randint(0, 100)}.key', # Random ssl certificate key
|
||||
'ssl_ciphers': f'ciphers{random.randint(0, 100)}', # Random ssl ciphers
|
||||
'ssl_dhparam': f'/tmp/uds_tunnel_{random.randint(0, 100)}.dh', # Random ssl dhparam
|
||||
'uds_server': f'https://uds_server{random.randint(0, 100)}/some_path', # Random uds server
|
||||
'uds_token': f'uds_token{random.choices(string.ascii_uppercase + string.digits, k=32)}', # Random uds token
|
||||
'secret': f'secret{random.randint(0, 100)}', # Random secret
|
||||
'allow': f'{random.randint(0, 255)}.0.0.0', # Random allow
|
||||
values, cfg = fixtures.get_config()
|
||||
|
||||
}
|
||||
h = hashlib.sha256()
|
||||
h.update(values.get('secret', '').encode())
|
||||
secret = h.hexdigest()
|
||||
# Generate an in-memory configuration file from fixtures.TEST_CONFIG
|
||||
config_file = io.StringIO(fixtures.TEST_CONFIG.format(**values))
|
||||
# Read it
|
||||
cfg = config.read(config_file)
|
||||
# Ensure data is correct
|
||||
self.assertEqual(cfg.pidfile, values['pidfile'])
|
||||
self.assertEqual(cfg.user, values['user'])
|
||||
self.assertEqual(cfg.log_level, values['loglevel'])
|
||||
self.assertEqual(cfg.log_file, values['logfile'])
|
||||
self.assertEqual(cfg.log_size, values['logsize'] * 1024 * 1024) # Config file is in MB
|
||||
self.assertEqual(
|
||||
cfg.log_size, values['logsize'] * 1024 * 1024
|
||||
) # Config file is in MB
|
||||
self.assertEqual(cfg.log_number, values['lognumber'])
|
||||
self.assertEqual(cfg.listen_address, values['address'])
|
||||
self.assertEqual(cfg.workers, values['workers'])
|
||||
@ -85,8 +61,7 @@ class TestConfigFile(TestCase):
|
||||
self.assertEqual(cfg.ssl_dhparam, values['ssl_dhparam'])
|
||||
self.assertEqual(cfg.uds_server, values['uds_server'])
|
||||
self.assertEqual(cfg.uds_token, values['uds_token'])
|
||||
self.assertEqual(cfg.uds_timeout, values['uds_timeout'])
|
||||
self.assertEqual(cfg.secret, secret)
|
||||
self.assertEqual(cfg.allow, {values['allow']})
|
||||
|
||||
|
||||
|
||||
self.assertEqual(cfg.uds_verify_ssl, values['uds_verify_ssl'])
|
||||
|
160
tunnel-server/test/test_tunnel.py
Normal file
160
tunnel-server/test/test_tunnel.py
Normal file
@ -0,0 +1,160 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright (c) 2022 Virtual Cable S.L.U.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification,
|
||||
# are permitted provided that the following conditions are met:
|
||||
#
|
||||
# * Redistributions of source code must retain the above copyright notice,
|
||||
# this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
# * Neither the name of Virtual Cable S.L. nor the names of its contributors
|
||||
# may be used to endorse or promote products derived from this software
|
||||
# without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
'''
|
||||
Author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
'''
|
||||
import string
|
||||
import random
|
||||
import aiohttp
|
||||
|
||||
from unittest import IsolatedAsyncioTestCase, mock
|
||||
|
||||
from uds_tunnel import proxy, tunnel, consts
|
||||
|
||||
from . import fixtures
|
||||
from .utils import tools
|
||||
|
||||
NOTIFY_TICKET = '0123456789cdef01456789abcdebcdef0123456789abcdef'
|
||||
UDS_GET_TICKET_RESPONSE = {
|
||||
'host': '127.0.0.1',
|
||||
'port': 54876,
|
||||
'notify': NOTIFY_TICKET,
|
||||
}
|
||||
CALLER_HOST = ('host', 12345)
|
||||
REMOTE_HOST = ('127.0.0.1', 54876)
|
||||
|
||||
|
||||
class TestTunnel(IsolatedAsyncioTestCase):
|
||||
async def test_get_ticket_from_uds(self) -> None:
|
||||
_, cfg = fixtures.get_config()
|
||||
# Test some invalid tickets
|
||||
# Valid ticket are consts.TICKET_LENGTH bytes long, and must be A-Z, a-z, 0-9
|
||||
with mock.patch(
|
||||
'uds_tunnel.tunnel.TunnelProtocol._readFromUDS',
|
||||
new_callable=tools.AsyncMock,
|
||||
) as m:
|
||||
m.return_value = UDS_GET_TICKET_RESPONSE
|
||||
for i in range(0, 100):
|
||||
ticket = ''.join(
|
||||
random.choices(
|
||||
string.ascii_letters + string.digits, k=i % consts.TICKET_LENGTH
|
||||
)
|
||||
)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
await tunnel.TunnelProtocol.getTicketFromUDS(
|
||||
cfg, ticket.encode(), CALLER_HOST
|
||||
)
|
||||
|
||||
ticket = NOTIFY_TICKET # Samle ticket
|
||||
for i in range(0, 100):
|
||||
# Now some requests with valid tickets
|
||||
# Ensure no exception is raised
|
||||
ret_value = await tunnel.TunnelProtocol.getTicketFromUDS(
|
||||
cfg, ticket.encode(), CALLER_HOST
|
||||
)
|
||||
# Ensure data returned is correct {host, port, notify} from mock
|
||||
self.assertEqual(ret_value, m.return_value)
|
||||
# Ensure mock was called with correct parameters
|
||||
print(m.call_args)
|
||||
# Check calling parameters, first one is the config, second one is the ticket, third one is the caller host
|
||||
# no kwargs are used
|
||||
self.assertEqual(m.call_args[0][0], cfg)
|
||||
self.assertEqual(
|
||||
m.call_args[0][1], NOTIFY_TICKET.encode()
|
||||
) # Same ticket, but bytes
|
||||
self.assertEqual(m.call_args[0][2], CALLER_HOST[0])
|
||||
|
||||
print(ret_value)
|
||||
|
||||
# mock should have been called 100 times
|
||||
self.assertEqual(m.call_count, 100)
|
||||
|
||||
async def test_notify_end_to_uds(self) -> None:
|
||||
_, cfg = fixtures.get_config()
|
||||
with mock.patch(
|
||||
'uds_tunnel.tunnel.TunnelProtocol._readFromUDS',
|
||||
new_callable=tools.AsyncMock,
|
||||
) as m:
|
||||
m.return_value = {}
|
||||
counter = mock.MagicMock()
|
||||
counter.sent = 123456789
|
||||
counter.recv = 987654321
|
||||
|
||||
ticket = NOTIFY_TICKET.encode()
|
||||
for i in range(0, 100):
|
||||
await tunnel.TunnelProtocol.notifyEndToUds(cfg, ticket, counter)
|
||||
|
||||
self.assertEqual(m.call_args[0][0], cfg)
|
||||
self.assertEqual(
|
||||
m.call_args[0][1], NOTIFY_TICKET.encode()
|
||||
) # Same ticket, but bytes
|
||||
self.assertEqual(m.call_args[0][2], 'stop')
|
||||
self.assertEqual(
|
||||
m.call_args[0][3],
|
||||
{'sent': str(counter.sent), 'recv': str(counter.recv)},
|
||||
)
|
||||
|
||||
# mock should have been called 100 times
|
||||
self.assertEqual(m.call_count, 100)
|
||||
|
||||
async def test_read_from_uds(self) -> None:
|
||||
# Generate a listening http server for testing UDS
|
||||
# Tesst fine responses:
|
||||
for use_ssl in (True, False):
|
||||
async with tools.AsyncHttpServer(
|
||||
port=13579, response=b'{"result":"ok"}', use_ssl=use_ssl
|
||||
) as server:
|
||||
# Get server configuration, and ensure server is running fine
|
||||
fake_uds_server = (
|
||||
f'http{"s" if use_ssl else ""}://127.0.0.1:{server.port}/'
|
||||
)
|
||||
_, cfg = fixtures.get_config(
|
||||
uds_server=fake_uds_server, uds_verify_ssl=False
|
||||
)
|
||||
self.assertEqual(
|
||||
await TestTunnel.get(fake_uds_server),
|
||||
'{"result":"ok"}',
|
||||
)
|
||||
# Now, tests _readFromUDS
|
||||
for i in range(100):
|
||||
ret = await tunnel.TunnelProtocol._readFromUDS(
|
||||
cfg, NOTIFY_TICKET.encode(), 'test', {'param': 'value'}
|
||||
)
|
||||
self.assertEqual(ret, {'result': 'ok'})
|
||||
|
||||
# Helpers
|
||||
@staticmethod
|
||||
async def get(url: str) -> str:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
options = {
|
||||
'ssl': False,
|
||||
}
|
||||
async with session.get(url, **options) as r:
|
||||
r.raise_for_status()
|
||||
return await r.text()
|
0
tunnel-server/test/utils/__init__.py
Normal file
0
tunnel-server/test/utils/__init__.py
Normal file
52
tunnel-server/test/utils/certs.py
Normal file
52
tunnel-server/test/utils/certs.py
Normal file
@ -0,0 +1,52 @@
|
||||
import secrets
|
||||
import random
|
||||
from datetime import datetime, timedelta
|
||||
import ipaddress
|
||||
import typing
|
||||
|
||||
from cryptography import x509
|
||||
from cryptography.x509.oid import NameOID
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
|
||||
|
||||
def selfSignedCert(ip: str) -> typing.Tuple[str, str, str]:
|
||||
key = rsa.generate_private_key(
|
||||
public_exponent=65537,
|
||||
key_size=2048,
|
||||
backend=default_backend(),
|
||||
)
|
||||
# Create a random password for private key
|
||||
password = secrets.token_urlsafe(32)
|
||||
|
||||
name = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, ip)])
|
||||
san = x509.SubjectAlternativeName([x509.IPAddress(ipaddress.ip_address(ip))])
|
||||
|
||||
basic_contraints = x509.BasicConstraints(ca=True, path_length=0)
|
||||
now = datetime.utcnow()
|
||||
cert = (
|
||||
x509.CertificateBuilder()
|
||||
.subject_name(name)
|
||||
.issuer_name(name) # self signed, its Issuer DN must match its Subject DN.
|
||||
.public_key(key.public_key())
|
||||
.serial_number(random.SystemRandom().randint(0, 1 << 64))
|
||||
.not_valid_before(now)
|
||||
.not_valid_after(now + timedelta(days=10 * 365))
|
||||
.add_extension(basic_contraints, False)
|
||||
.add_extension(san, False)
|
||||
.sign(key, hashes.SHA256(), default_backend())
|
||||
)
|
||||
|
||||
return (
|
||||
key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.TraditionalOpenSSL,
|
||||
encryption_algorithm=serialization.BestAvailableEncryption(
|
||||
password.encode()
|
||||
),
|
||||
).decode(),
|
||||
cert.public_bytes(encoding=serialization.Encoding.PEM).decode(),
|
||||
password,
|
||||
)
|
109
tunnel-server/test/utils/tools.py
Normal file
109
tunnel-server/test/utils/tools.py
Normal file
@ -0,0 +1,109 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright (c) 2022 Virtual Cable S.L.U.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification,
|
||||
# are permitted provided that the following conditions are met:
|
||||
#
|
||||
# * Redistributions of source code must retain the above copyright notice,
|
||||
# this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
# * Neither the name of Virtual Cable S.L. nor the names of its contributors
|
||||
# may be used to endorse or promote products derived from this software
|
||||
# without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
'''
|
||||
Author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
'''
|
||||
import asyncio
|
||||
import os
|
||||
import ssl
|
||||
import typing
|
||||
import tempfile
|
||||
from unittest import mock
|
||||
|
||||
from . import certs
|
||||
|
||||
class AsyncMock(mock.MagicMock):
|
||||
async def __call__(self, *args, **kwargs):
|
||||
return super().__call__(*args, **kwargs)
|
||||
|
||||
|
||||
# simple async http server, will return 200 OK with the request path as body
|
||||
class AsyncHttpServer:
|
||||
port: int
|
||||
_server: typing.Optional[asyncio.AbstractServer]
|
||||
_response: typing.Optional[bytes]
|
||||
_ssl_ctx: typing.Optional[ssl.SSLContext]
|
||||
|
||||
def __init__(
|
||||
self, port: int, *, response: typing.Optional[bytes] = None, use_ssl: bool = False,
|
||||
host: str = '127.0.0.1' # ip
|
||||
):
|
||||
self.port = port
|
||||
self._server = None
|
||||
self._response = response
|
||||
if use_ssl:
|
||||
# First, create server cert and key on temp dir
|
||||
tmpdir = tempfile.gettempdir()
|
||||
cert, key, password = certs.selfSignedCert('127.0.0.1')
|
||||
with open(f'{tmpdir}/tmp_cert.pem', 'w') as f:
|
||||
f.write(key)
|
||||
f.write(cert)
|
||||
# Create SSL context
|
||||
self._ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||
self._ssl_ctx.options |= ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1
|
||||
self._ssl_ctx.load_cert_chain(certfile=f'{tmpdir}/tmp_cert.pem', password=password)
|
||||
self._ssl_ctx.check_hostname = False
|
||||
self._ssl_ctx.set_ciphers(
|
||||
'ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384'
|
||||
)
|
||||
else:
|
||||
self._ssl_ctx = None
|
||||
|
||||
# on end, remove certs
|
||||
def __del__(self):
|
||||
tmpdir = tempfile.gettempdir()
|
||||
# os.remove(f'{tmpdir}/tmp_cert.pem')
|
||||
|
||||
async def _handle(self, reader, writer):
|
||||
data = await reader.read(2048)
|
||||
path: bytes = data.split()[1]
|
||||
if self._response is not None:
|
||||
path = self._response
|
||||
writer.write(
|
||||
b'HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: %d\r\n\r\n%s'
|
||||
% (len(path), path)
|
||||
)
|
||||
await writer.drain()
|
||||
|
||||
async def __aenter__(self):
|
||||
if self._ssl_ctx is not None:
|
||||
self._server = await asyncio.start_server(
|
||||
self._handle, '127.0.0.1', self.port, ssl=self._ssl_ctx
|
||||
)
|
||||
else:
|
||||
self._server = await asyncio.start_server(
|
||||
self._handle, '127.0.0.1', self.port
|
||||
)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if self._server is not None:
|
||||
self._server.close()
|
||||
await self._server.wait_closed()
|
||||
self._server = None
|
||||
|
Loading…
Reference in New Issue
Block a user