mirror of
https://github.com/dkmstr/openuds.git
synced 2025-03-13 08:58:35 +03:00
improved tunnel server and tunnel server tests
This commit is contained in:
parent
370799912f
commit
0f5f3df3f0
@ -126,9 +126,7 @@ async def main():
|
||||
data = client.recv(4)
|
||||
print(data)
|
||||
# Upgrade connection to SSL, and use asyncio to handle the rest
|
||||
transport: 'asyncio.transports.Transport'
|
||||
protocol: TunnelProtocol
|
||||
(transport, protocol) = await loop.connect_accepted_socket( # type: ignore
|
||||
(_, protocol) = await loop.connect_accepted_socket( # type: ignore
|
||||
lambda: TunnelProtocol(), client, ssl=context
|
||||
)
|
||||
|
||||
|
@ -43,14 +43,15 @@ class ConfigurationType(typing.NamedTuple):
|
||||
pidfile: str
|
||||
user: str
|
||||
|
||||
log_level: str
|
||||
log_file: str
|
||||
log_size: int
|
||||
log_number: int
|
||||
loglevel: str
|
||||
logfile: str
|
||||
logsize: int
|
||||
lognumber: int
|
||||
|
||||
listen_address: str
|
||||
listen_port: int
|
||||
listen_ipv6: bool
|
||||
|
||||
ipv6: bool
|
||||
|
||||
workers: int
|
||||
|
||||
@ -113,13 +114,13 @@ def read(
|
||||
return ConfigurationType(
|
||||
pidfile=uds.get('pidfile', ''),
|
||||
user=uds.get('user', ''),
|
||||
log_level=uds.get('loglevel', 'ERROR'),
|
||||
log_file=uds.get('logfile', ''),
|
||||
log_size=int(logsize) * 1024 * 1024,
|
||||
log_number=int(uds.get('lognumber', '3')),
|
||||
loglevel=uds.get('loglevel', 'ERROR'),
|
||||
logfile=uds.get('logfile', ''),
|
||||
logsize=int(logsize) * 1024 * 1024,
|
||||
lognumber=int(uds.get('lognumber', '3')),
|
||||
listen_address=uds.get('address', '0.0.0.0'),
|
||||
listen_port=int(uds.get('port', '443')),
|
||||
listen_ipv6=uds.get('ipv6', 'false').lower() == 'true',
|
||||
ipv6=uds.get('ipv6', 'false').lower() == 'true',
|
||||
workers=int(uds.get('workers', '0')) or multiprocessing.cpu_count(),
|
||||
ssl_certificate=uds['ssl_certificate'],
|
||||
ssl_certificate_key=uds.get('ssl_certificate_key', ''),
|
||||
|
@ -46,10 +46,12 @@ logger = logging.getLogger(__name__)
|
||||
class Proxy:
|
||||
cfg: 'config.ConfigurationType'
|
||||
ns: 'Namespace'
|
||||
finished: asyncio.Future
|
||||
|
||||
def __init__(self, cfg: 'config.ConfigurationType', ns: 'Namespace') -> None:
|
||||
self.cfg = cfg
|
||||
self.ns = ns
|
||||
self.finished = asyncio.Future() # not done yet
|
||||
|
||||
# Method responsible of proxying requests
|
||||
async def __call__(self, source: socket.socket, context: 'ssl.SSLContext') -> None:
|
||||
@ -63,7 +65,7 @@ class Proxy:
|
||||
addr = source.getpeername()
|
||||
except Exception:
|
||||
addr = 'Unknown'
|
||||
logger.error('Proxy error from %s: %s', addr, e)
|
||||
logger.exception('Proxy error from %s: %s (%s--%s)', addr, e, source, context)
|
||||
|
||||
async def proxy(self, source: socket.socket, context: 'ssl.SSLContext') -> None:
|
||||
loop = asyncio.get_running_loop()
|
||||
@ -72,13 +74,17 @@ class Proxy:
|
||||
|
||||
# Upgrade connection to SSL, and use asyncio to handle the rest
|
||||
try:
|
||||
protocol: tunnel.TunnelProtocol
|
||||
def factory() -> tunnel.TunnelProtocol:
|
||||
return tunnel.TunnelProtocol(self)
|
||||
# (connect accepted loop not present on AbastractEventLoop definition < 3.10), that's why we use ignore
|
||||
(_, protocol) = await loop.connect_accepted_socket( # type: ignore
|
||||
lambda: tunnel.TunnelProtocol(self), source, ssl=context
|
||||
await loop.connect_accepted_socket( # type: ignore
|
||||
factory, source, ssl=context
|
||||
)
|
||||
|
||||
await protocol.finished
|
||||
# Wait for connection to be closed
|
||||
await self.finished
|
||||
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass # Return on cancel
|
||||
|
||||
|
@ -31,6 +31,7 @@ Author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
import asyncio
|
||||
import typing
|
||||
import logging
|
||||
import socket
|
||||
|
||||
import aiohttp
|
||||
|
||||
@ -46,8 +47,6 @@ if typing.TYPE_CHECKING:
|
||||
|
||||
# Protocol
|
||||
class TunnelProtocol(asyncio.Protocol):
|
||||
# future to mark eof
|
||||
finished: asyncio.Future
|
||||
# Transport and other side of tunnel
|
||||
transport: 'asyncio.transports.Transport'
|
||||
other_side: 'TunnelProtocol'
|
||||
@ -86,7 +85,6 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
self.runner = self.do_command
|
||||
|
||||
# transport is undefined until connection_made is called
|
||||
self.finished = asyncio.Future()
|
||||
self.cmd = b''
|
||||
self.notify_ticket = b''
|
||||
self.owner = owner
|
||||
@ -136,10 +134,12 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
)
|
||||
|
||||
try:
|
||||
family = socket.AF_INET6 if ':' in self.destination[0] or self.owner.cfg.ipv6 else socket.AF_INET
|
||||
(_, protocol) = await loop.create_connection(
|
||||
lambda: TunnelProtocol(self.owner, self),
|
||||
self.destination[0],
|
||||
self.destination[1],
|
||||
family=family,
|
||||
)
|
||||
self.other_side = typing.cast('TunnelProtocol', protocol)
|
||||
|
||||
@ -151,6 +151,7 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
logger.error('Error opening connection: %s', e)
|
||||
self.close_connection()
|
||||
|
||||
# add open other side to the loop
|
||||
loop.create_task(open_other_side())
|
||||
# From now, proxy connection
|
||||
self.runner = self.do_proxy
|
||||
@ -280,7 +281,9 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
|
||||
def connection_lost(self, exc: typing.Optional[Exception]) -> None:
|
||||
logger.debug('Connection closed : %s', exc)
|
||||
self.finished.set_result(True)
|
||||
# notify end to parent proxy
|
||||
self.owner.finished.set_result(True)
|
||||
# Ensure close other side if any
|
||||
if self.other_side is not self:
|
||||
self.other_side.transport.close()
|
||||
else:
|
||||
@ -288,12 +291,17 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
self.notifyEnd()
|
||||
|
||||
# helpers
|
||||
@staticmethod
|
||||
def pretty_address(address: typing.Tuple[str, int]) -> str:
|
||||
if ':' in address[0]:
|
||||
return '[' + address[0] + ']:' + str(address[1])
|
||||
return address[0] + ':' + str(address[1])
|
||||
# source address, pretty format
|
||||
def pretty_source(self) -> str:
|
||||
return self.source[0] + ':' + str(self.source[1])
|
||||
return TunnelProtocol.pretty_address(self.source)
|
||||
|
||||
def pretty_destination(self) -> str:
|
||||
return self.destination[0] + ':' + str(self.destination[1])
|
||||
return TunnelProtocol.pretty_address(self.destination)
|
||||
|
||||
def close_connection(self):
|
||||
self.transport.close()
|
||||
|
@ -23,8 +23,9 @@ address = 0.0.0.0
|
||||
# Listening port
|
||||
port = 7777
|
||||
|
||||
# If force ipv6 listen, defaults to false
|
||||
# If force ipv6, defaults to false
|
||||
# Note: if listen address is an ipv6 address, this will be forced to true
|
||||
# This will force dns resolution to ipv6
|
||||
ipv6 = false
|
||||
|
||||
# Number of workers. Defaults to 0 (means "as much as cores")
|
||||
|
@ -77,26 +77,26 @@ def setup_log(cfg: config.ConfigurationType) -> None:
|
||||
from logging.handlers import RotatingFileHandler
|
||||
|
||||
# Update logging if needed
|
||||
if cfg.log_file:
|
||||
if cfg.logfile:
|
||||
fileh = RotatingFileHandler(
|
||||
filename=cfg.log_file,
|
||||
filename=cfg.logfile,
|
||||
mode='a',
|
||||
maxBytes=cfg.log_size,
|
||||
backupCount=cfg.log_number,
|
||||
maxBytes=cfg.logsize,
|
||||
backupCount=cfg.lognumber,
|
||||
)
|
||||
formatter = logging.Formatter(consts.LOGFORMAT)
|
||||
fileh.setFormatter(formatter)
|
||||
log = logging.getLogger()
|
||||
log.setLevel(cfg.log_level)
|
||||
log.setLevel(cfg.loglevel)
|
||||
# for hdlr in log.handlers[:]:
|
||||
# log.removeHandler(hdlr)
|
||||
log.addHandler(fileh)
|
||||
else:
|
||||
# Setup basic logging
|
||||
log = logging.getLogger()
|
||||
log.setLevel(cfg.log_level)
|
||||
log.setLevel(cfg.loglevel)
|
||||
handler = logging.StreamHandler(sys.stderr)
|
||||
handler.setLevel(cfg.log_level)
|
||||
handler.setLevel(cfg.loglevel)
|
||||
formatter = logging.Formatter(
|
||||
'%(levelname)s - %(message)s'
|
||||
) # Basic log format, nice for syslog
|
||||
|
@ -43,25 +43,15 @@ class TestConfigFile(TestCase):
|
||||
|
||||
h = hashlib.sha256()
|
||||
h.update(values.get('secret', '').encode())
|
||||
secret = h.hexdigest()
|
||||
# Adapt some values to config
|
||||
values['secret'] = h.hexdigest()
|
||||
values['allow'] = {values['allow']} # convert to set
|
||||
values['logsize'] = values['logsize'] * 1024 * 1024
|
||||
values['listen_address'] = values['address']
|
||||
values['listen_port'] = values['port']
|
||||
|
||||
del values['address']
|
||||
del values['port']
|
||||
# Ensure data is correct
|
||||
self.assertEqual(cfg.pidfile, values['pidfile'])
|
||||
self.assertEqual(cfg.user, values['user'])
|
||||
self.assertEqual(cfg.log_level, values['loglevel'])
|
||||
self.assertEqual(cfg.log_file, values['logfile'])
|
||||
self.assertEqual(
|
||||
cfg.log_size, values['logsize'] * 1024 * 1024
|
||||
) # Config file is in MB
|
||||
self.assertEqual(cfg.log_number, values['lognumber'])
|
||||
self.assertEqual(cfg.listen_address, values['address'])
|
||||
self.assertEqual(cfg.workers, values['workers'])
|
||||
self.assertEqual(cfg.ssl_certificate, values['ssl_certificate'])
|
||||
self.assertEqual(cfg.ssl_certificate_key, values['ssl_certificate_key'])
|
||||
self.assertEqual(cfg.ssl_ciphers, values['ssl_ciphers'])
|
||||
self.assertEqual(cfg.ssl_dhparam, values['ssl_dhparam'])
|
||||
self.assertEqual(cfg.uds_server, values['uds_server'])
|
||||
self.assertEqual(cfg.uds_token, values['uds_token'])
|
||||
self.assertEqual(cfg.uds_timeout, values['uds_timeout'])
|
||||
self.assertEqual(cfg.secret, secret)
|
||||
self.assertEqual(cfg.allow, {values['allow']})
|
||||
self.assertEqual(cfg.uds_verify_ssl, values['uds_verify_ssl'])
|
||||
for k, v in values.items():
|
||||
self.assertEqual(getattr(cfg, k), v, f'Error in {k}')
|
||||
|
@ -44,6 +44,11 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TestTunnel(IsolatedAsyncioTestCase):
|
||||
async def asyncSetUp(self) -> None:
|
||||
# Disable logging os slow tests
|
||||
logging.disable(logging.WARNING)
|
||||
return await super().asyncSetUp()
|
||||
|
||||
async def test_tunnel_invalid_command(self) -> None:
|
||||
# Test invalid handshake
|
||||
# data = b''
|
||||
|
@ -41,7 +41,6 @@ from .utils import fixtures
|
||||
from .utils import tools, conf
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.disable(logging.WARNING)
|
||||
|
||||
def uds_response(
|
||||
_,
|
||||
@ -56,6 +55,11 @@ def uds_response(
|
||||
|
||||
|
||||
class TestTunnelHelpers(IsolatedAsyncioTestCase):
|
||||
async def asyncSetUp(self) -> None:
|
||||
# Disable logging os slow tests
|
||||
logging.disable(logging.WARNING)
|
||||
return await super().asyncSetUp()
|
||||
|
||||
async def test_get_ticket_from_uds_broker(self) -> None:
|
||||
_, cfg = fixtures.get_config()
|
||||
# Test some invalid tickets
|
||||
|
@ -41,27 +41,33 @@ from .utils import tuntools
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TestUDSTunnel(IsolatedAsyncioTestCase):
|
||||
class TestUDSTunnelApp(IsolatedAsyncioTestCase):
|
||||
async def asyncSetUp(self) -> None:
|
||||
# Disable logging os slow tests
|
||||
logging.disable(logging.WARNING)
|
||||
return await super().asyncSetUp()
|
||||
|
||||
async def test_tunnel_fail_cmd_full(self) -> None:
|
||||
async def test_tunnel_fail_cmd(self) -> None:
|
||||
consts.TIMEOUT_COMMAND = 0.1 # type: ignore # timeout is a final variable, but we need to change it for testing speed
|
||||
for i in range(0, 100, 10):
|
||||
# Set timeout to 1 seconds
|
||||
bad_cmd = bytes(random.randint(0, 255) for _ in range(i)) # Some garbage
|
||||
logger.info(f'Testing invalid command with {bad_cmd!r}')
|
||||
async with tuntools.create_tunnel_proc(
|
||||
'127.0.0.1', 7777, '127.0.0.1', 12345, workers=1
|
||||
) as cfg:
|
||||
# On full, we need the handshake to be done, before connecting
|
||||
async with tuntools.open_tunnel_client(cfg, use_tunnel_handshake=True) as (creader, cwriter):
|
||||
cwriter.write(bad_cmd)
|
||||
await cwriter.drain()
|
||||
# Read response
|
||||
data = await creader.read(1024)
|
||||
# if len(bad_cmd) < consts.COMMAND_LENGTH, response will be RESPONSE_ERROR_TIMEOUT
|
||||
if len(bad_cmd) >= consts.COMMAND_LENGTH:
|
||||
self.assertEqual(data, consts.RESPONSE_ERROR_COMMAND)
|
||||
else:
|
||||
self.assertEqual(data, consts.RESPONSE_ERROR_TIMEOUT)
|
||||
|
||||
for host in ('127.0.0.1', '::1'):
|
||||
# Remote is not really important in this tests, will fail before using it
|
||||
async with tuntools.create_tunnel_proc(
|
||||
host, 7777, '127.0.0.1', 12345, workers=1
|
||||
) as cfg:
|
||||
# On full, we need the handshake to be done, before connecting
|
||||
async with tuntools.open_tunnel_client(cfg, use_tunnel_handshake=True) as (creader, cwriter):
|
||||
cwriter.write(bad_cmd)
|
||||
await cwriter.drain()
|
||||
# Read response
|
||||
data = await creader.read(1024)
|
||||
# if len(bad_cmd) < consts.COMMAND_LENGTH, response will be RESPONSE_ERROR_TIMEOUT
|
||||
if len(bad_cmd) >= consts.COMMAND_LENGTH:
|
||||
self.assertEqual(data, consts.RESPONSE_ERROR_COMMAND)
|
||||
else:
|
||||
self.assertEqual(data, consts.RESPONSE_ERROR_TIMEOUT)
|
||||
|
||||
|
||||
|
@ -57,6 +57,9 @@ lognumber = {lognumber}
|
||||
# Listen address. Defaults to 0.0.0.0
|
||||
address = {address}
|
||||
|
||||
# If enforce ipv6. Defaults to False
|
||||
ipv6 = {ipv6}
|
||||
|
||||
# Listen port. Defaults to 443
|
||||
port = {port}
|
||||
|
||||
@ -93,7 +96,7 @@ allow = {allow}
|
||||
use_uvloop = {use_uvloop}
|
||||
'''
|
||||
|
||||
def get_config(**overrides) -> typing.Tuple[typing.Mapping[str, typing.Any], config.ConfigurationType]:
|
||||
def get_config(**overrides) -> typing.Tuple[typing.Dict[str, typing.Any], config.ConfigurationType]:
|
||||
rand_number = random.randint(0, 100)
|
||||
values: typing.Dict[str, typing.Any] = {
|
||||
'pidfile': f'/tmp/uds_tunnel_{random.randint(0, 100)}.pid', # Random pid file
|
||||
|
@ -75,6 +75,7 @@ async def create_tunnel_proc(
|
||||
values, cfg = fixtures.get_config(
|
||||
address=listen_host,
|
||||
port=listen_port,
|
||||
ipv6=':' in listen_host,
|
||||
ssl_certificate=cert_file,
|
||||
ssl_certificate_key='',
|
||||
ssl_password=password,
|
||||
@ -163,7 +164,7 @@ async def create_tunnel_server(
|
||||
cfg.listen_port,
|
||||
ssl=context,
|
||||
family=socket.AF_INET6
|
||||
if cfg.listen_ipv6 or ':' in cfg.listen_address
|
||||
if cfg.ipv6 or ':' in cfg.listen_address
|
||||
else socket.AF_INET,
|
||||
)
|
||||
|
||||
@ -181,6 +182,7 @@ async def create_test_tunnel(
|
||||
_, cfg = fixtures.get_config(
|
||||
address=server.host,
|
||||
port=7777,
|
||||
ipv6=':' in server.host,
|
||||
)
|
||||
with mock.patch(
|
||||
'uds_tunnel.tunnel.TunnelProtocol._readFromUDS',
|
||||
@ -205,10 +207,7 @@ async def open_tunnel_client(
|
||||
]:
|
||||
"""opens an ssl socket to the tunnel server"""
|
||||
loop = asyncio.get_running_loop()
|
||||
if cfg.listen_ipv6 or ':' in cfg.listen_address:
|
||||
family = socket.AF_INET6
|
||||
else:
|
||||
family = socket.AF_INET
|
||||
family = socket.AF_INET6 if cfg.ipv6 or ':' in cfg.listen_address else socket.AF_INET
|
||||
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
context.check_hostname = False
|
||||
context.verify_mode = ssl.CERT_NONE
|
||||
|
Loading…
x
Reference in New Issue
Block a user