1
0
mirror of https://github.com/dkmstr/openuds.git synced 2025-03-20 06:50:23 +03:00

final fixes for uds tunnel

This commit is contained in:
Adolfo Gómez García 2022-12-22 15:25:45 +01:00
parent 796d000f8e
commit f91aeb3ffd
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
12 changed files with 151 additions and 60 deletions

View File

@ -92,7 +92,6 @@ class TunnelProtocol(asyncio.Protocol):
logger.error('Invalid state reached!')
def connection_lost(self, exc: typing.Optional[Exception]) -> None:
logger.debug('Connection closed : %s', exc)
self.finished.set_result(True)
if self.other_side is not self:
self.other_side.transport.close()

View File

@ -66,11 +66,20 @@ class ConfigurationType(typing.NamedTuple):
uds_timeout: int
uds_verify_ssl: bool
command_timeout: float
secret: str
allow: typing.Set[str]
use_uvloop: bool
def __str__(self) -> str:
return 'Configuration: \n' + '\n'.join(
f'{k}={v}'
for k, v in self._asdict().items()
)
def read_config_file(
cfg_file: typing.Optional[typing.Union[typing.TextIO, str]] = None
@ -131,6 +140,7 @@ def read(
uds_token=uds.get('uds_token', 'unauthorized'),
uds_timeout=int(uds.get('uds_timeout', '10')),
uds_verify_ssl=uds.get('uds_verify_ssl', 'true').lower() == 'true',
command_timeout=float(uds.get('command_timeout', '3')),
secret=secret,
allow=set(uds.get('allow', '127.0.0.1').split(',')),
use_uvloop=uds.get('use_uvloop', 'true').lower() == 'true',

View File

@ -69,8 +69,5 @@ RESPONSE_FORBIDDEN: typing.Final[bytes] = b'FORBIDDEN'
RESPONSE_OK: typing.Final[bytes] = b'OK'
# Timeout for command
TIMEOUT_COMMAND: typing.Final[int] = 3
# Backlog for listen socket
BACKLOG = 1024

View File

@ -73,6 +73,13 @@ class TunnelProtocol(asyncio.Protocol):
) -> None:
# If no other side is given, we are the server part
super().__init__()
# transport is undefined until connection_made is called
self.cmd = b''
self.notify_ticket = b''
self.owner = owner
self.source = ('', 0)
self.destination = ('', 0)
if other_side:
self.other_side = other_side
self.stats_manager = other_side.stats_manager
@ -84,21 +91,15 @@ class TunnelProtocol(asyncio.Protocol):
self.counter = self.stats_manager.as_sent_counter()
self.runner = self.do_command
# Set starting timeout task, se we dont get hunged on connections without data
self.set_timeout(consts.TIMEOUT_COMMAND)
self.set_timeout(self.owner.cfg.command_timeout)
# transport is undefined until connection_made is called
self.cmd = b''
self.notify_ticket = b''
self.owner = owner
self.source = ('', 0)
self.destination = ('', 0)
def process_open(self) -> None:
# Open Command has the ticket behind it
if len(self.cmd) < consts.TICKET_LENGTH + consts.COMMAND_LENGTH:
# Reactivate timeout, will be deactivated on do_command
self.set_timeout(consts.TIMEOUT_COMMAND)
self.set_timeout(self.owner.cfg.command_timeout)
return # Wait for more data to complete OPEN command
# Ticket received, now process it with UDS
@ -196,7 +197,7 @@ class TunnelProtocol(asyncio.Protocol):
finally:
self.close_connection()
async def timeout(self, wait: int) -> None:
async def timeout(self, wait: float) -> None:
"""Timeout can only occur while waiting for a command (or OPEN command ticket)."""
try:
await asyncio.sleep(wait)
@ -206,7 +207,7 @@ class TunnelProtocol(asyncio.Protocol):
except asyncio.CancelledError:
pass
def set_timeout(self, wait: int) -> None:
def set_timeout(self, wait: float) -> None:
"""Set a timeout for this connection.
If reached, the connection will be closed.
@ -253,7 +254,7 @@ class TunnelProtocol(asyncio.Protocol):
self.close_connection()
return
else:
self.set_timeout(consts.TIMEOUT_COMMAND)
self.set_timeout(self.owner.cfg.command_timeout)
# if not enough data to process command, wait for more
@ -289,7 +290,6 @@ class TunnelProtocol(asyncio.Protocol):
self.owner.finished.set()
def connection_lost(self, exc: typing.Optional[Exception]) -> None:
logger.debug('Connection closed : %s', exc)
# Ensure close other side if any
if self.other_side is not self:
self.other_side.transport.close()

View File

@ -53,6 +53,10 @@ uds_token = eBCeFxTBw1IKXCqq-RlncshwWIfrrqxc8y5nehqiqMtRztwD
# If verify ssl certificate on uds server. Defaults to true
# uds_verify_ssl = true
# Command timeout. Command reception on tunnel will timeout after this time (in seconds)
# defaults to 3 seconds
# command_timeout = 3
# Secret to get access to admin commands (Currently only stats commands). No default for this.
# Admin commands and only allowed from "allow" ips
# So, in order to allow this commands, ensure listen address allows connections from localhost

View File

@ -102,6 +102,10 @@ def setup_log(cfg: config.ConfigurationType) -> None:
handler.setFormatter(formatter)
log.addHandler(handler)
# If debug, print config
if cfg.loglevel.lower() == 'debug':
logger.debug('Configuration: %s', cfg)
async def tunnel_proc_async(
pipe: 'Connection', cfg: config.ConfigurationType, ns: 'Namespace'
@ -111,6 +115,10 @@ async def tunnel_proc_async(
tasks: typing.List[asyncio.Task] = []
def add_autoremovable_task(task: asyncio.Task) -> None:
tasks.append(task)
task.add_done_callback(tasks.remove)
def get_socket() -> typing.Tuple[typing.Optional[socket.socket], typing.Optional[typing.Tuple[str, int]]]:
try:
while True:
@ -157,7 +165,7 @@ async def tunnel_proc_async(
break # No more sockets, exit
logger.debug(f'CONNECTION from {address!r} (pid: {os.getpid()})')
# Due to proxy contains an "event" to stop, we need to create a new one for each connection
tasks.append(asyncio.create_task(proxy.Proxy(cfg, ns)(sock, context)))
add_autoremovable_task(asyncio.create_task(proxy.Proxy(cfg, ns)(sock, context)))
except asyncio.CancelledError:
raise
except Exception:
@ -166,23 +174,20 @@ async def tunnel_proc_async(
pass # Stop
# create task for server
tasks.append(asyncio.create_task(run_server()))
add_autoremovable_task(asyncio.create_task(run_server()))
try:
while tasks and not do_stop.is_set():
to_wait = tasks[:] # Get a copy of the list
# Wait for "to_wait" tasks to finish, stop every 2 seconds to check if we need to stop
done, _ = await asyncio.wait(to_wait, return_when=asyncio.FIRST_COMPLETED, timeout=2)
# Remove finished tasks
for task in done:
tasks.remove(task)
if task.exception():
logger.exception('TUNNEL ERROR')
except asyncio.CancelledError:
logger.info('Task cancelled')
do_stop.set() # ensure we stop
logger.debug('Out of loop, stopping tasks: %s, running: %s', tasks, do_stop.is_set())
# If any task is still running, cancel it
for task in tasks:
task.cancel()

View File

@ -45,14 +45,6 @@ logger = logging.getLogger(__name__)
class TestUDSTunnelApp(IsolatedAsyncioTestCase):
async def test_run_app_help(self) -> None:
# Executes the app with --help
async with tuntools.tunnel_app_runner(args=['--help']) as process:
stdout, stderr = await process.communicate()
self.assertEqual(process.returncode, 0, f'{stdout!r} {stderr!r}')
self.assertEqual(stderr, b'')
self.assertIn(b'usage: udstunnel', stdout)
async def client_task(self, host: str, tunnel_port: int, remote_port: int) -> None:
received: bytes = b''
@ -118,8 +110,8 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
await callback_invoked.wait()
self.assertEqual(received, test_str)
async def test_run_app_serve(self) -> None:
concurrent_tasks = 256
async def test_app_concurrency(self) -> None:
concurrent_tasks = 512
fake_broker_port = 20000
tunnel_server_port = fake_broker_port + 1
remote_port = fake_broker_port + 2
@ -154,13 +146,15 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
logfile='/tmp/tunnel_test.log',
loglevel='DEBUG',
workers=4,
command_timeout=16, # Increase command timeout because heavy load we will create
) as process:
# Create a "bunch" of clients
tasks = [
asyncio.create_task(
self.client_task(host, tunnel_server_port, remote_port + i)
)
for i in range(concurrent_tasks)
async for i in tools.waitable_range(concurrent_tasks)
]
# Wait for all tasks to finish
@ -171,3 +165,52 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
task.result()
# Queue should have all requests (concurrent_tasks*2, one for open and one for close)
self.assertEqual(req_queue.qsize(), concurrent_tasks * 2)
async def test_tunnel_proc_concurrency(self) -> None:
concurrent_tasks = 512
fake_broker_port = 20000
tunnel_server_port = fake_broker_port + 1
remote_port = fake_broker_port + 2
# Extracts the port from an string that has bX0bwmbPORTbX0bwmb in it
req_queue: asyncio.Queue[bytes] = asyncio.Queue()
def extract_port(data: bytes) -> int:
req_queue.put_nowait(data)
if b'bX0bwmb' not in data:
return 12345 # No port, wil not be used because is an "stop" request
return int(data.split(b'bX0bwmb')[1])
for host in ('127.0.0.1', '::1'):
if ':' in host:
url = f'http://[{host}]:{fake_broker_port}/uds/rest'
else:
url = f'http://{host}:{fake_broker_port}/uds/rest'
req_queue = asyncio.Queue() # clear queue
# Use tunnel proc for testing
async with tuntools.create_tunnel_proc(
host,
tunnel_server_port,
response=lambda data: conf.UDS_GET_TICKET_RESPONSE(
host, extract_port(data)
),
command_timeout=16, # Increase command timeout because heavy load we will create
) as (cfg, _):
# Create a "bunch" of clients
tasks = [
asyncio.create_task(
self.client_task(host, tunnel_server_port, remote_port + i)
)
async for i in tools.waitable_range(concurrent_tasks)
]
# Wait for tasks to finish and check for exceptions
await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)
# If any exception was raised, raise it
for task in tasks:
task.result()
# Queue should have all requests (concurrent_tasks*2, one for open and one for close)
self.assertEqual(req_queue.qsize(), concurrent_tasks * 2)

View File

@ -61,12 +61,11 @@ class TestTunnel(IsolatedAsyncioTestCase):
# Send invalid commands and see what happens
# Commands are 4 bytes length, try with less and more invalid commands
consts.TIMEOUT_COMMAND = 0.1 # type: ignore # timeout is a final variable, but we need to change it for testing speed
for i in range(0, 100, 10):
# Set timeout to 1 seconds
bad_cmd = bytes(random.randint(0, 255) for _ in range(i)) # Some garbage
logger.info(f'Testing invalid command with {bad_cmd!r}')
async with tuntools.create_test_tunnel(callback=lambda x: None, port=7770, remote_port=54555) as cfg:
async with tuntools.create_test_tunnel(callback=lambda x: None, port=7770, remote_port=54555, command_timeout=0.1) as cfg:
logger_mock = mock.MagicMock()
with mock.patch('uds_tunnel.tunnel.logger', logger_mock):
# Open connection to tunnel

View File

@ -31,13 +31,12 @@ Author: Adolfo Gómez, dkmaster at dkmon dot com
import typing
import random
import asyncio
import string
import logging
from unittest import IsolatedAsyncioTestCase, mock
from uds_tunnel import consts
from .utils import tuntools, tools
from .utils import tuntools, tools, conf
logger = logging.getLogger(__name__)
@ -48,8 +47,16 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
logging.disable(logging.WARNING)
return await super().asyncSetUp()
async def test_run_app_help(self) -> None:
# Executes the app with --help
async with tuntools.tunnel_app_runner(args=['--help']) as process:
stdout, stderr = await process.communicate()
self.assertEqual(process.returncode, 0, f'{stdout!r} {stderr!r}')
self.assertEqual(stderr, b'')
self.assertIn(b'usage: udstunnel', stdout)
async def test_tunnel_fail_cmd(self) -> None:
consts.TIMEOUT_COMMAND = 0.1 # type: ignore # timeout is a final variable, but we need to change it for testing speed
# Test on ipv4 and ipv6
for host in ('127.0.0.1', '::1'):
# Remote is not really important in this tests, will fail before using it
@ -58,6 +65,7 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
7890, # A port not used by any other test
'127.0.0.1',
13579, # A port not used by any other test
command_timeout=0.1,
) as (cfg, queue):
for i in range(0, 8192, 128):
# Set timeout to 1 seconds
@ -102,7 +110,6 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
self.assertEqual(data, consts.RESPONSE_OK)
async def test_tunnel_fail_open(self) -> None:
consts.TIMEOUT_COMMAND = 0.1 # type: ignore # timeout is a final variable, but we need to change it for testing speed
for host in ('127.0.0.1', '::1'):
# Remote is NOT important in this tests
# create a remote server
@ -112,6 +119,7 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
7775,
server.host,
server.port,
command_timeout=0.1,
) as (cfg, queue):
for i in range(
0, consts.TICKET_LENGTH - 1, 4
@ -152,7 +160,7 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
for tunnel_host in ('127.0.0.1', '::1'):
async with tuntools.create_tunnel_proc(
tunnel_host,
7778,
7778, # Not really used here
server.host,
server.port,
use_fake_http_server=True,

View File

@ -94,6 +94,10 @@ secret = {secret}
# defaults to localhost (change if listen address is different from 0.0.0.0)
allow = {allow}
# Command timeout. Command reception on tunnel will timeout after this time (in seconds)
# defaults to 3 seconds
command_timeout = {command_timeout}
use_uvloop = {use_uvloop}
'''
@ -121,6 +125,7 @@ def get_config(**overrides) -> typing.Tuple[typing.Dict[str, typing.Any], config
'uds_verify_ssl': random.choice([True, False]), # Random verify uds ssl
'secret': f'secret{random.randint(0, 100)}', # Random secret
'allow': f'{random.randint(0, 255)}.0.0.0', # Random allow
'command_timeout': random.randint(0, 100), # Random command timeout
'use_uvloop': random.choice([True, False]), # Random use uvloop
}
values.update(overrides)

View File

@ -195,3 +195,8 @@ async def wait_for_port(host: str, port: int) -> None:
return
except ConnectionRefusedError:
await asyncio.sleep(0.1)
async def waitable_range(len: int, wait: float = 0.0001) -> typing.AsyncGenerator[int, None]:
for i in range(len):
await asyncio.sleep(wait)
yield i

View File

@ -98,22 +98,29 @@ def create_config_file(
finally:
pass
# Remove the files if they exists
# for filename in (cfgfile, cert_file):
# try:
# os.remove(filename)
# except Exception:
# pass
for filename in (cfgfile, cert_file):
try:
os.remove(filename)
except Exception:
pass
@contextlib.asynccontextmanager
async def create_tunnel_proc(
listen_host: str,
listen_port: int,
remote_host: str,
remote_port: int,
remote_host: str = '0.0.0.0', # Not used if response is provided
remote_port: int = 0, # Not used if response is provided
*,
response: typing.Optional[typing.Mapping[str, typing.Any]] = None,
response: typing.Optional[
typing.Union[
typing.Callable[[bytes], typing.Mapping[str, typing.Any]],
typing.Mapping[str, typing.Any],
]
] = None,
use_fake_http_server: bool = False,
# Configuration parameters
**kwargs,
) -> typing.AsyncGenerator[
typing.Tuple['config.ConfigurationType', typing.Optional[asyncio.Queue[bytes]]],
None,
@ -126,7 +133,7 @@ async def create_tunnel_proc(
listen_port (int): Port to listen on
remote_host (str): Remote host to connect to
remote_port (int): Remote port to connect to
response (typing.Optional[typing.Mapping[str, typing.Any]], optional): Response to send to the tunnel. Defaults to None.
response (typing.Optional[typing.Union[typing.Callable[[bytes], typing.Mapping[str, typing.Any]], typing.Mapping[str, typing.Any]]], optional): Response to send to the client. Defaults to None.
use_fake_http_server (bool, optional): If True, a fake http server will be used instead of a mock. Defaults to False.
Yields:
@ -134,11 +141,16 @@ async def create_tunnel_proc(
and a queue with the data received by the "fake_http_server" if used, or None if not used
"""
# Ensure response
if response is None:
response = conf.UDS_GET_TICKET_RESPONSE(remote_host, remote_port)
port = random.randint(20000, 30000)
hhost = f'[{listen_host}]' if ':' in listen_host else listen_host
args = {
'uds_server': f'http://{hhost}:{port}/uds/rest',
}
args.update(kwargs) # Add extra args
# If use http server instead of mock
# We will setup a different context provider
if use_fake_http_server:
@ -149,7 +161,7 @@ async def create_tunnel_proc(
typing.Optional[asyncio.Queue[bytes]], None
]:
async with create_fake_broker_server(
listen_host, port, response=resp
listen_host, port, response=response or resp
) as queue:
try:
yield queue
@ -166,7 +178,10 @@ async def create_tunnel_proc(
'uds_tunnel.tunnel.TunnelProtocol._read_from_uds',
new_callable=tools.AsyncMock,
) as m:
m.return_value = response
if callable(response):
m.side_effect = lambda cfg, ticket, *args, **kwargs: response(ticket) # type: ignore
else:
m.return_value = response
try:
yield None
finally:
@ -181,10 +196,6 @@ async def create_tunnel_proc(
# Load config here also for testing
cfg = config.read(cfgfile)
# Ensure response
if response is None:
response = conf.UDS_GET_TICKET_RESPONSE(remote_host, remote_port)
async with provider() as possible_queue:
# Stats collector
gs = stats.GlobalStats()
@ -283,6 +294,8 @@ async def create_test_tunnel(
callback: typing.Callable[[bytes], None],
port: typing.Optional[int] = None,
remote_port: typing.Optional[int] = None,
# Configuration parameters
**kwargs: typing.Any,
) -> typing.AsyncGenerator['config.ConfigurationType', None]:
# Generate a listening server for testing tunnel
# Prepare the end of the tunnel
@ -296,6 +309,7 @@ async def create_test_tunnel(
address=server.host,
port=port or 7771,
ipv6=':' in server.host,
**kwargs,
)
with mock.patch(
'uds_tunnel.tunnel.TunnelProtocol._read_from_uds',
@ -316,9 +330,11 @@ async def create_fake_broker_server(
host: str,
port: int,
*,
response: typing.Union[
typing.Callable[[bytes], typing.Mapping[str, typing.Any]],
typing.Mapping[str, typing.Any],
response: typing.Optional[
typing.Union[
typing.Callable[[bytes], typing.Mapping[str, typing.Any]],
typing.Mapping[str, typing.Any],
]
],
) -> typing.AsyncGenerator[asyncio.Queue[bytes], None]:
# crate a fake broker server
@ -343,7 +359,7 @@ async def create_fake_broker_server(
if callable(response):
rr = response(data)
else:
rr = response
rr = response or {}
resp: bytes = (
b'HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n\r\n'