1
0
mirror of https://github.com/dkmstr/openuds.git synced 2025-03-13 08:58:35 +03:00

improved tunnel server and tunnel server tests

This commit is contained in:
Adolfo Gómez García 2022-12-18 03:42:19 +01:00
parent 370799912f
commit 0f5f3df3f0
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
12 changed files with 98 additions and 77 deletions

View File

@ -126,9 +126,7 @@ async def main():
data = client.recv(4)
print(data)
# Upgrade connection to SSL, and use asyncio to handle the rest
transport: 'asyncio.transports.Transport'
protocol: TunnelProtocol
(transport, protocol) = await loop.connect_accepted_socket( # type: ignore
(_, protocol) = await loop.connect_accepted_socket( # type: ignore
lambda: TunnelProtocol(), client, ssl=context
)

View File

@ -43,14 +43,15 @@ class ConfigurationType(typing.NamedTuple):
pidfile: str
user: str
log_level: str
log_file: str
log_size: int
log_number: int
loglevel: str
logfile: str
logsize: int
lognumber: int
listen_address: str
listen_port: int
listen_ipv6: bool
ipv6: bool
workers: int
@ -113,13 +114,13 @@ def read(
return ConfigurationType(
pidfile=uds.get('pidfile', ''),
user=uds.get('user', ''),
log_level=uds.get('loglevel', 'ERROR'),
log_file=uds.get('logfile', ''),
log_size=int(logsize) * 1024 * 1024,
log_number=int(uds.get('lognumber', '3')),
loglevel=uds.get('loglevel', 'ERROR'),
logfile=uds.get('logfile', ''),
logsize=int(logsize) * 1024 * 1024,
lognumber=int(uds.get('lognumber', '3')),
listen_address=uds.get('address', '0.0.0.0'),
listen_port=int(uds.get('port', '443')),
listen_ipv6=uds.get('ipv6', 'false').lower() == 'true',
ipv6=uds.get('ipv6', 'false').lower() == 'true',
workers=int(uds.get('workers', '0')) or multiprocessing.cpu_count(),
ssl_certificate=uds['ssl_certificate'],
ssl_certificate_key=uds.get('ssl_certificate_key', ''),

View File

@ -46,10 +46,12 @@ logger = logging.getLogger(__name__)
class Proxy:
cfg: 'config.ConfigurationType'
ns: 'Namespace'
finished: asyncio.Future
def __init__(self, cfg: 'config.ConfigurationType', ns: 'Namespace') -> None:
self.cfg = cfg
self.ns = ns
self.finished = asyncio.Future() # not done yet
# Method responsible of proxying requests
async def __call__(self, source: socket.socket, context: 'ssl.SSLContext') -> None:
@ -63,7 +65,7 @@ class Proxy:
addr = source.getpeername()
except Exception:
addr = 'Unknown'
logger.error('Proxy error from %s: %s', addr, e)
logger.exception('Proxy error from %s: %s (%s--%s)', addr, e, source, context)
async def proxy(self, source: socket.socket, context: 'ssl.SSLContext') -> None:
loop = asyncio.get_running_loop()
@ -72,13 +74,17 @@ class Proxy:
# Upgrade connection to SSL, and use asyncio to handle the rest
try:
protocol: tunnel.TunnelProtocol
def factory() -> tunnel.TunnelProtocol:
return tunnel.TunnelProtocol(self)
# (connect accepted loop not present on AbastractEventLoop definition < 3.10), that's why we use ignore
(_, protocol) = await loop.connect_accepted_socket( # type: ignore
lambda: tunnel.TunnelProtocol(self), source, ssl=context
await loop.connect_accepted_socket( # type: ignore
factory, source, ssl=context
)
await protocol.finished
# Wait for connection to be closed
await self.finished
except asyncio.CancelledError:
pass # Return on cancel

View File

@ -31,6 +31,7 @@ Author: Adolfo Gómez, dkmaster at dkmon dot com
import asyncio
import typing
import logging
import socket
import aiohttp
@ -46,8 +47,6 @@ if typing.TYPE_CHECKING:
# Protocol
class TunnelProtocol(asyncio.Protocol):
# future to mark eof
finished: asyncio.Future
# Transport and other side of tunnel
transport: 'asyncio.transports.Transport'
other_side: 'TunnelProtocol'
@ -86,7 +85,6 @@ class TunnelProtocol(asyncio.Protocol):
self.runner = self.do_command
# transport is undefined until connection_made is called
self.finished = asyncio.Future()
self.cmd = b''
self.notify_ticket = b''
self.owner = owner
@ -136,10 +134,12 @@ class TunnelProtocol(asyncio.Protocol):
)
try:
family = socket.AF_INET6 if ':' in self.destination[0] or self.owner.cfg.ipv6 else socket.AF_INET
(_, protocol) = await loop.create_connection(
lambda: TunnelProtocol(self.owner, self),
self.destination[0],
self.destination[1],
family=family,
)
self.other_side = typing.cast('TunnelProtocol', protocol)
@ -151,6 +151,7 @@ class TunnelProtocol(asyncio.Protocol):
logger.error('Error opening connection: %s', e)
self.close_connection()
# add open other side to the loop
loop.create_task(open_other_side())
# From now, proxy connection
self.runner = self.do_proxy
@ -280,7 +281,9 @@ class TunnelProtocol(asyncio.Protocol):
def connection_lost(self, exc: typing.Optional[Exception]) -> None:
logger.debug('Connection closed : %s', exc)
self.finished.set_result(True)
# notify end to parent proxy
self.owner.finished.set_result(True)
# Ensure close other side if any
if self.other_side is not self:
self.other_side.transport.close()
else:
@ -288,12 +291,17 @@ class TunnelProtocol(asyncio.Protocol):
self.notifyEnd()
# helpers
@staticmethod
def pretty_address(address: typing.Tuple[str, int]) -> str:
if ':' in address[0]:
return '[' + address[0] + ']:' + str(address[1])
return address[0] + ':' + str(address[1])
# source address, pretty format
def pretty_source(self) -> str:
return self.source[0] + ':' + str(self.source[1])
return TunnelProtocol.pretty_address(self.source)
def pretty_destination(self) -> str:
return self.destination[0] + ':' + str(self.destination[1])
return TunnelProtocol.pretty_address(self.destination)
def close_connection(self):
self.transport.close()

View File

@ -23,8 +23,9 @@ address = 0.0.0.0
# Listening port
port = 7777
# If force ipv6 listen, defaults to false
# If force ipv6, defaults to false
# Note: if listen address is an ipv6 address, this will be forced to true
# This will force dns resolution to ipv6
ipv6 = false
# Number of workers. Defaults to 0 (means "as much as cores")

View File

@ -77,26 +77,26 @@ def setup_log(cfg: config.ConfigurationType) -> None:
from logging.handlers import RotatingFileHandler
# Update logging if needed
if cfg.log_file:
if cfg.logfile:
fileh = RotatingFileHandler(
filename=cfg.log_file,
filename=cfg.logfile,
mode='a',
maxBytes=cfg.log_size,
backupCount=cfg.log_number,
maxBytes=cfg.logsize,
backupCount=cfg.lognumber,
)
formatter = logging.Formatter(consts.LOGFORMAT)
fileh.setFormatter(formatter)
log = logging.getLogger()
log.setLevel(cfg.log_level)
log.setLevel(cfg.loglevel)
# for hdlr in log.handlers[:]:
# log.removeHandler(hdlr)
log.addHandler(fileh)
else:
# Setup basic logging
log = logging.getLogger()
log.setLevel(cfg.log_level)
log.setLevel(cfg.loglevel)
handler = logging.StreamHandler(sys.stderr)
handler.setLevel(cfg.log_level)
handler.setLevel(cfg.loglevel)
formatter = logging.Formatter(
'%(levelname)s - %(message)s'
) # Basic log format, nice for syslog

View File

@ -43,25 +43,15 @@ class TestConfigFile(TestCase):
h = hashlib.sha256()
h.update(values.get('secret', '').encode())
secret = h.hexdigest()
# Adapt some values to config
values['secret'] = h.hexdigest()
values['allow'] = {values['allow']} # convert to set
values['logsize'] = values['logsize'] * 1024 * 1024
values['listen_address'] = values['address']
values['listen_port'] = values['port']
del values['address']
del values['port']
# Ensure data is correct
self.assertEqual(cfg.pidfile, values['pidfile'])
self.assertEqual(cfg.user, values['user'])
self.assertEqual(cfg.log_level, values['loglevel'])
self.assertEqual(cfg.log_file, values['logfile'])
self.assertEqual(
cfg.log_size, values['logsize'] * 1024 * 1024
) # Config file is in MB
self.assertEqual(cfg.log_number, values['lognumber'])
self.assertEqual(cfg.listen_address, values['address'])
self.assertEqual(cfg.workers, values['workers'])
self.assertEqual(cfg.ssl_certificate, values['ssl_certificate'])
self.assertEqual(cfg.ssl_certificate_key, values['ssl_certificate_key'])
self.assertEqual(cfg.ssl_ciphers, values['ssl_ciphers'])
self.assertEqual(cfg.ssl_dhparam, values['ssl_dhparam'])
self.assertEqual(cfg.uds_server, values['uds_server'])
self.assertEqual(cfg.uds_token, values['uds_token'])
self.assertEqual(cfg.uds_timeout, values['uds_timeout'])
self.assertEqual(cfg.secret, secret)
self.assertEqual(cfg.allow, {values['allow']})
self.assertEqual(cfg.uds_verify_ssl, values['uds_verify_ssl'])
for k, v in values.items():
self.assertEqual(getattr(cfg, k), v, f'Error in {k}')

View File

@ -44,6 +44,11 @@ logger = logging.getLogger(__name__)
class TestTunnel(IsolatedAsyncioTestCase):
async def asyncSetUp(self) -> None:
# Disable logging os slow tests
logging.disable(logging.WARNING)
return await super().asyncSetUp()
async def test_tunnel_invalid_command(self) -> None:
# Test invalid handshake
# data = b''

View File

@ -41,7 +41,6 @@ from .utils import fixtures
from .utils import tools, conf
logger = logging.getLogger(__name__)
logging.disable(logging.WARNING)
def uds_response(
_,
@ -56,6 +55,11 @@ def uds_response(
class TestTunnelHelpers(IsolatedAsyncioTestCase):
async def asyncSetUp(self) -> None:
# Disable logging os slow tests
logging.disable(logging.WARNING)
return await super().asyncSetUp()
async def test_get_ticket_from_uds_broker(self) -> None:
_, cfg = fixtures.get_config()
# Test some invalid tickets

View File

@ -41,27 +41,33 @@ from .utils import tuntools
logger = logging.getLogger(__name__)
class TestUDSTunnel(IsolatedAsyncioTestCase):
class TestUDSTunnelApp(IsolatedAsyncioTestCase):
async def asyncSetUp(self) -> None:
# Disable logging os slow tests
logging.disable(logging.WARNING)
return await super().asyncSetUp()
async def test_tunnel_fail_cmd_full(self) -> None:
async def test_tunnel_fail_cmd(self) -> None:
consts.TIMEOUT_COMMAND = 0.1 # type: ignore # timeout is a final variable, but we need to change it for testing speed
for i in range(0, 100, 10):
# Set timeout to 1 seconds
bad_cmd = bytes(random.randint(0, 255) for _ in range(i)) # Some garbage
logger.info(f'Testing invalid command with {bad_cmd!r}')
async with tuntools.create_tunnel_proc(
'127.0.0.1', 7777, '127.0.0.1', 12345, workers=1
) as cfg:
# On full, we need the handshake to be done, before connecting
async with tuntools.open_tunnel_client(cfg, use_tunnel_handshake=True) as (creader, cwriter):
cwriter.write(bad_cmd)
await cwriter.drain()
# Read response
data = await creader.read(1024)
# if len(bad_cmd) < consts.COMMAND_LENGTH, response will be RESPONSE_ERROR_TIMEOUT
if len(bad_cmd) >= consts.COMMAND_LENGTH:
self.assertEqual(data, consts.RESPONSE_ERROR_COMMAND)
else:
self.assertEqual(data, consts.RESPONSE_ERROR_TIMEOUT)
for host in ('127.0.0.1', '::1'):
# Remote is not really important in this tests, will fail before using it
async with tuntools.create_tunnel_proc(
host, 7777, '127.0.0.1', 12345, workers=1
) as cfg:
# On full, we need the handshake to be done, before connecting
async with tuntools.open_tunnel_client(cfg, use_tunnel_handshake=True) as (creader, cwriter):
cwriter.write(bad_cmd)
await cwriter.drain()
# Read response
data = await creader.read(1024)
# if len(bad_cmd) < consts.COMMAND_LENGTH, response will be RESPONSE_ERROR_TIMEOUT
if len(bad_cmd) >= consts.COMMAND_LENGTH:
self.assertEqual(data, consts.RESPONSE_ERROR_COMMAND)
else:
self.assertEqual(data, consts.RESPONSE_ERROR_TIMEOUT)

View File

@ -57,6 +57,9 @@ lognumber = {lognumber}
# Listen address. Defaults to 0.0.0.0
address = {address}
# If enforce ipv6. Defaults to False
ipv6 = {ipv6}
# Listen port. Defaults to 443
port = {port}
@ -93,7 +96,7 @@ allow = {allow}
use_uvloop = {use_uvloop}
'''
def get_config(**overrides) -> typing.Tuple[typing.Mapping[str, typing.Any], config.ConfigurationType]:
def get_config(**overrides) -> typing.Tuple[typing.Dict[str, typing.Any], config.ConfigurationType]:
rand_number = random.randint(0, 100)
values: typing.Dict[str, typing.Any] = {
'pidfile': f'/tmp/uds_tunnel_{random.randint(0, 100)}.pid', # Random pid file

View File

@ -75,6 +75,7 @@ async def create_tunnel_proc(
values, cfg = fixtures.get_config(
address=listen_host,
port=listen_port,
ipv6=':' in listen_host,
ssl_certificate=cert_file,
ssl_certificate_key='',
ssl_password=password,
@ -163,7 +164,7 @@ async def create_tunnel_server(
cfg.listen_port,
ssl=context,
family=socket.AF_INET6
if cfg.listen_ipv6 or ':' in cfg.listen_address
if cfg.ipv6 or ':' in cfg.listen_address
else socket.AF_INET,
)
@ -181,6 +182,7 @@ async def create_test_tunnel(
_, cfg = fixtures.get_config(
address=server.host,
port=7777,
ipv6=':' in server.host,
)
with mock.patch(
'uds_tunnel.tunnel.TunnelProtocol._readFromUDS',
@ -205,10 +207,7 @@ async def open_tunnel_client(
]:
"""opens an ssl socket to the tunnel server"""
loop = asyncio.get_running_loop()
if cfg.listen_ipv6 or ':' in cfg.listen_address:
family = socket.AF_INET6
else:
family = socket.AF_INET
family = socket.AF_INET6 if cfg.ipv6 or ':' in cfg.listen_address else socket.AF_INET
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
context.check_hostname = False
context.verify_mode = ssl.CERT_NONE