1
0
mirror of https://github.com/dkmstr/openuds.git synced 2025-03-11 00:58:39 +03:00

some fixes to make tests work as they are expected to

This commit is contained in:
Adolfo Gómez García 2023-05-21 16:19:58 +02:00
parent 99cd7030e0
commit c33c1501f5
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
10 changed files with 140 additions and 101 deletions

View File

@ -68,11 +68,12 @@ ignored-modules=
# Python code to execute, usually for sys.path manipulation such as
# pygtk.require().
#init-hook=
init-hook='import sys, os; sys.path.append(os.path.join(os.getcwd(), "src"))'
# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
# number of processors available to use, and will cap the count on Windows to
# avoid hangs.
jobs=1
jobs=4
# Control the amount of potential inferred values when inferring a single
# object. This can help the performance when dealing with large functions or

View File

@ -73,7 +73,7 @@ init-hook='import sys, os; sys.path.append(os.path.join(os.getcwd(), "src"))'
# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
# number of processors available to use, and will cap the count on Windows to
# avoid hangs.
jobs=1
jobs=4
# Control the amount of potential inferred values when inferring a single
# object. This can help the performance when dealing with large functions or

View File

@ -74,11 +74,17 @@ class Proxy:
# Handshake correct in this point, upgrade the connection to TSL and let
# the protocol controller do the rest
# Store source ip and port, for logging purposes in case of error
src_ip, src_port = (source.getpeername() if source else ('Unknown', 0))[:2] # May be ipv4 or ipv6, so we get only first two elements
# Upgrade connection to SSL, and use asyncio to handle the rest
tun: typing.Optional[tunnel.TunnelProtocol] = None
try:
tun = tunnel.TunnelProtocol(self)
# (connect accepted loop not present on AbastractEventLoop definition < 3.10), that's why we use ignore
await loop.connect_accepted_socket( # type: ignore
lambda: tunnel.TunnelProtocol(self), source, ssl=context
lambda: tun, source, ssl=context,
ssl_handshake_timeout=3,
)
# Wait for connection to be closed
@ -86,6 +92,11 @@ class Proxy:
except asyncio.CancelledError:
pass # Return on cancel
except Exception as e:
# Any other exception, ensure we close the connection
logger.error('ERROR on %s:%s: %s', src_ip, src_port, e)
if tun:
tun.close_connection()
logger.debug('Proxy finished')

View File

@ -97,8 +97,6 @@ class TunnelProtocol(asyncio.Protocol):
# We start processing command
# After command, we can process stats or do_proxy, that is the "normal" operation
self.runner = self.do_command
# Set starting timeout task, se we dont get hunged on connections without data (or insufficient data)
self.set_timeout(self.owner.cfg.command_timeout)
def process_open(self) -> None:
# Open Command has the ticket behind it
@ -158,7 +156,7 @@ class TunnelProtocol(asyncio.Protocol):
self.transport.write(b'OK')
self.stats_manager.increment_connections() # Increment connections counters
except Exception as e:
logger.error('Error opening connection: %s', e)
logger.error('CONNECTION FAILED: %s', e)
self.close_connection()
# add open other side to the loop
@ -276,9 +274,10 @@ class TunnelProtocol(asyncio.Protocol):
def close_connection(self):
try:
self.clean_timeout() # If a timeout is set, clean it
if not self.transport.is_closing():
self.transport.close()
except Exception:
except Exception: # nosec: best effort
pass # Ignore errors
def notify_end(self):
@ -304,6 +303,10 @@ class TunnelProtocol(asyncio.Protocol):
def connection_made(self, transport: 'asyncio.transports.BaseTransport') -> None:
# We know for sure that the transport is a Transport.
# Set starting timeout task, se we dont get hunged on connections without data (or insufficient data)
self.set_timeout(self.owner.cfg.command_timeout)
self.transport = typing.cast('asyncio.transports.Transport', transport)
# Get source
self.source = self.transport.get_extra_info('peername')

View File

@ -45,13 +45,14 @@ from concurrent.futures import ThreadPoolExecutor
try:
import uvloop
import uvloop # type: ignore
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
except ImportError:
pass # no uvloop support
try:
import setproctitle
import setproctitle # type: ignore
except ImportError:
setproctitle = None # type: ignore
@ -111,7 +112,12 @@ async def tunnel_proc_async(pipe: 'Connection', cfg: config.ConfigurationType, n
def add_autoremovable_task(task: asyncio.Task) -> None:
tasks.append(task)
task.add_done_callback(tasks.remove)
def remove_task(task: asyncio.Task) -> None:
logger.debug('Removing task %s', task)
tasks.remove(task)
task.add_done_callback(remove_task)
def get_socket() -> typing.Tuple[typing.Optional[socket.socket], typing.Optional[typing.Tuple[str, int]]]:
try:
@ -174,7 +180,9 @@ async def tunnel_proc_async(pipe: 'Connection', cfg: config.ConfigurationType, n
break # No more sockets, exit
logger.debug('CONNECTION from %s (pid: %s)', address, os.getpid())
# Due to proxy contains an "event" to stop, we need to create a new one for each connection
add_autoremovable_task(asyncio.create_task(proxy.Proxy(cfg, ns)(sock, context)))
add_autoremovable_task(
asyncio.create_task(proxy.Proxy(cfg, ns)(sock, context), name=f'proxy-{address}')
)
except asyncio.CancelledError: # pylint: disable=try-except-raise
raise # Stop, but avoid generic exception
except Exception:
@ -184,7 +192,7 @@ async def tunnel_proc_async(pipe: 'Connection', cfg: config.ConfigurationType, n
# create task for server
add_autoremovable_task(asyncio.create_task(run_server()))
add_autoremovable_task(asyncio.create_task(run_server(), name='server'))
try:
while tasks and not do_stop.is_set():
@ -195,6 +203,9 @@ async def tunnel_proc_async(pipe: 'Connection', cfg: config.ConfigurationType, n
except asyncio.CancelledError:
logger.info('Task cancelled')
do_stop.set() # ensure we stop
except Exception:
logger.exception('Error in main loop')
do_stop.set()
logger.debug('Out of loop, stopping tasks: %s, running: %s', tasks, do_stop.is_set())
@ -236,8 +247,10 @@ def tunnel_main(args: 'argparse.Namespace') -> None:
# Try to bind to port as running user
# Wait for socket incoming connections and spread them
socket.setdefaulttimeout(3.0) # So we can check for stop from time to time and not block forever
af_inet = socket.AF_INET6 if args.ipv6 or cfg.ipv6 or ':' in cfg.listen_address else socket.AF_INET
sock = socket.socket(af_inet, socket.SOCK_STREAM)
sock = socket.socket(
socket.AF_INET6 if args.ipv6 or cfg.ipv6 or ':' in cfg.listen_address else socket.AF_INET,
socket.SOCK_STREAM,
)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
# We will not reuse port, we only want a UDS tunnel server running on a port

View File

@ -45,7 +45,7 @@ logger = logging.getLogger(__name__)
class TestUDSTunnelApp(IsolatedAsyncioTestCase):
async def client_task(self, host: str, tunnel_port: int, remote_port: int) -> None:
async def client_task(self, host: str, tunnel_port: int, remote_port: int, use_tunnel_handshake: bool = False) -> None:
received: bytes = b''
callback_invoked: asyncio.Event = asyncio.Event()
# Data sent will be received by server
@ -77,7 +77,7 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
cfg.listen_port = tunnel_port
async with tuntools.open_tunnel_client(
cfg, local_port=remote_port + 10000, use_tunnel_handshake=True
cfg, local_port=remote_port + 10000, use_tunnel_handshake=use_tunnel_handshake
) as (
creader,
cwriter,
@ -120,7 +120,7 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
self.assertEqual(received, test_str)
async def test_app_concurrency(self) -> None:
concurrent_tasks = 512
concurrent_tasks = 1024
fake_broker_port = 20000
tunnel_server_port = fake_broker_port + 1
remote_port = fake_broker_port + 2
@ -158,7 +158,7 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
) as process: # pylint: disable=unused-variable
# Create a "bunch" of clients
tasks = [
asyncio.create_task(self.client_task(host, tunnel_server_port, remote_port + i))
asyncio.create_task(self.client_task(host, tunnel_server_port, remote_port + i, use_tunnel_handshake=True))
async for i in tools.waitable_range(concurrent_tasks)
]

View File

@ -62,7 +62,7 @@ class TestTunnel(IsolatedAsyncioTestCase):
# Commands are 4 bytes length, try with less and more invalid commands
for i in range(0, 100, 10):
# Set timeout to 1 seconds
bad_cmd = bytes(random.randint(0, 255) for _ in range(i)) # Some garbage
bad_cmd = bytes(random.randint(0, 255) for _ in range(i)) # nosec: Some garbage, not security related
logger.info('Testing invalid command with %s', bad_cmd)
async with tuntools.create_test_tunnel(callback=lambda x: None, port=7770, remote_port=54555, command_timeout=0.1) as cfg:
logger_mock = mock.MagicMock()
@ -95,18 +95,18 @@ class TestTunnel(IsolatedAsyncioTestCase):
def test_tunnel_invalid_handshake(self) -> None:
# Not async test, executed on main thread without event loop
# Pipe for testing
own_conn, other_conn = multiprocessing.Pipe()
own_conn, other_conn = multiprocessing.Pipe() # pylint: disable=unused-variable
# Some random data to send on each test, all invalid
# 0 bytes will make timeout to be reached
for i in [i for i in range(10)] + [i for i in range(100, 10000, 100)]:
for i in list(range(10)) + list(range(100, 10000, 100)):
# Create a simple socket for testing
rsock, wsock = socket.socketpair()
# Set read timeout to 1 seconds
rsock.settimeout(3)
# Set timeout to 1 seconds
bad_handshake = bytes(random.randint(0, 255) for _ in range(i))
bad_handshake = bytes(random.randint(0, 255) for _ in range(i)) # nosec not for security
logger_mock = mock.MagicMock()
with mock.patch('udstunnel.logger', logger_mock):
wsock.sendall(bad_handshake)
@ -138,4 +138,3 @@ class TestTunnel(IsolatedAsyncioTestCase):
# and that other_conn has received a ('host', 'port') tuple
# recv()[0] will be a copy of the socket, we don't care about it
self.assertEqual(other_conn.recv()[1], ('host', 'port'))

View File

@ -49,7 +49,6 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
async def test_run_app_help(self) -> None:
# Executes the app with --help
async with tuntools.tunnel_app_runner(args=['--help']) as process:
stdout, stderr = await process.communicate()
self.assertEqual(process.returncode, 0, f'{stdout!r} {stderr!r}')
self.assertEqual(stderr, b'')
@ -57,7 +56,7 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
async def test_tunnel_fail_cmd(self) -> None:
# Test on ipv4 and ipv6
for host in ('127.0.0.1', '::1'):
for host in ('::1', '127.0.0.1'):
# Remote is not really important in this tests, will fail before using it
async with tuntools.create_tunnel_proc(
host,
@ -65,18 +64,17 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
'127.0.0.1',
13579, # A port not used by any other test
command_timeout=0.1,
) as (cfg, queue):
) as (cfg, queue): # pylint: disable=unused-variable
for i in range(0, 8192, 128):
# Set timeout to 1 seconds
bad_cmd = bytes(
random.randint(0, 255) for _ in range(i)
) # Some garbage
logger.info(f'Testing invalid command with {bad_cmd!r}')
bad_cmd = bytes(random.randint(0, 255) for _ in range(i)) # nosec: Some garbage
logger.info('Testing invalid command with %s', bad_cmd)
# On full, we need the handshake to be done, before connecting
# Our "test" server will simple "eat" the handshake, but we need to do it
async with tuntools.open_tunnel_client(
cfg, use_tunnel_handshake=True
) as (creader, cwriter):
async with tuntools.open_tunnel_client(cfg) as (
creader,
cwriter,
):
cwriter.write(bad_cmd)
await cwriter.drain()
# Read response
@ -95,13 +93,14 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
7891,
'127.0.0.1',
13581,
) as (cfg, queue):
for i in range(10): # Several times
) as (cfg, queue): # pylint: disable=unused-variable
for _ in range(10): # Several times
# On full, we need the handshake to be done, before connecting
# Our "test" server will simple "eat" the handshake, but we need to do it
async with tuntools.open_tunnel_client(
cfg, use_tunnel_handshake=True
) as (creader, cwriter):
async with tuntools.open_tunnel_client(cfg) as (
creader,
cwriter,
):
cwriter.write(consts.COMMAND_TEST)
await cwriter.drain()
# Read response
@ -119,7 +118,7 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
server.host,
server.port,
command_timeout=0.1,
) as (cfg, queue):
) as (cfg, queue): # pylint: disable=unused-variable
for i in range(
0, consts.TICKET_LENGTH - 1, 4
): # All will fail. Any longer will be processed, and mock will return correct don't matter the ticket
@ -127,9 +126,10 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
ticket = tuntools.get_correct_ticket(i)
# On full, we need the handshake to be done, before connecting
# Our "test" server will simple "eat" the handshake, but we need to do it
async with tuntools.open_tunnel_client(
cfg, use_tunnel_handshake=True
) as (creader, cwriter):
async with tuntools.open_tunnel_client(cfg) as (
creader,
cwriter,
):
cwriter.write(consts.COMMAND_OPEN)
# fake ticket, consts.TICKET_LENGTH bytes long, letters and numbers. Use a random ticket,
cwriter.write(ticket)
@ -149,13 +149,11 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
received += data
# if data contains EOS marcker ('STREAM_END'), we are done
if b'STREAM_END' in data:
callback_invoked.set()
callback_invoked.set() # pylint: disable=cell-var-from-loop
# Remote is important in this tests
# create a remote server, use a different port than the tunnel fail test, because tests may run in parallel
async with tools.AsyncTCPServer(
host=host, port=5445, callback=callback
) as server:
async with tools.AsyncTCPServer(host=host, port=5445, callback=callback) as server:
for tunnel_host in ('127.0.0.1', '::1'):
async with tuntools.create_tunnel_proc(
tunnel_host,
@ -165,17 +163,19 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
use_fake_http_server=True,
) as (cfg, queue):
# Ensure queue is not none but an asyncio.Queue
# Note, this also let's mypy know that queue is not None after this point
if queue is None:
raise AssertionError('Queue is None')
for i in range(16):
for _ in range(16):
# Create a random ticket with valid format
ticket = tuntools.get_correct_ticket()
# On full, we need the handshake to be done, before connecting
# Our "test" server will simple "eat" the handshake, but we need to do it
async with tuntools.open_tunnel_client(
cfg, use_tunnel_handshake=True
) as (creader, cwriter):
async with tuntools.open_tunnel_client(cfg) as (
creader,
cwriter,
):
cwriter.write(consts.COMMAND_OPEN)
# fake ticket, consts.TICKET_LENGTH bytes long, letters and numbers. Use a random ticket,
cwriter.write(ticket)
@ -199,14 +199,14 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
else:
should_be_url = f'/stop/{cfg.uds_token}'.encode()
self.assertIn(should_be_url, queue_item)
# Ensure user agent is correct
# Data sent will be received by server
# One single write will ensure all data is on same packet
test_str = (
b'Some Random Data'
+ bytes(random.randint(0, 255) for _ in range(8192))
+ bytes(random.randint(0, 255) for _ in range(8192)) # nosec: some random data, not used for security
+ b'STREAM_END'
)
# Clean received data
@ -233,9 +233,10 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
ticket = tuntools.get_correct_ticket()
# On full, we need the handshake to be done, before connecting
# Our "test" server will simple "eat" the handshake, but we need to do it
async with tuntools.open_tunnel_client(
cfg, use_tunnel_handshake=True
) as (creader, cwriter):
async with tuntools.open_tunnel_client(cfg) as (
creader,
cwriter,
):
cwriter.write(consts.COMMAND_OPEN)
# fake ticket, consts.TICKET_LENGTH bytes long, letters and numbers. Use a random ticket,
cwriter.write(ticket)

View File

@ -32,18 +32,21 @@ import asyncio
import os
import ssl
import typing
import collections.abc
import socket
import aiohttp
import logging
from unittest import mock
import aiohttp
from . import certs
if typing.TYPE_CHECKING:
import collections.abc
logger = logging.getLogger(__name__)
class AsyncMock(mock.MagicMock):
async def __call__(self, *args, **kwargs):
async def __call__(self, *args, **kwargs): # pylint: disable=invalid-overridden-method
return super().__call__(*args, **kwargs)
@ -69,7 +72,7 @@ class AsyncHttpServer:
self._server = None
self._response = response
if use_ssl:
self._ssl_ctx, self._ssl_cert_file, pwd = certs.sslContext(host)
self._ssl_ctx, self._ssl_cert_file, pwd = certs.sslContext() # pylint: disable=unused-variable
else:
self._ssl_ctx = None
self._ssl_cert_file = None
@ -150,15 +153,18 @@ class AsyncTCPServer:
await self._processor(reader, writer)
return
while True:
data = await reader.read(4096) # Care with this and tunnel handshake on testings...
if not data:
break
try:
data = await reader.read(4096) # Care with this and tunnel handshake on testings...
if not data:
break
resp = self._callback(data) if self._callback else self._response
resp = self._callback(data) if self._callback else self._response
if resp is not None:
writer.write(resp)
await writer.drain()
if resp is not None:
writer.write(resp)
await writer.drain()
except Exception as e:
logger.exception('Exception %s on %s', e, self._name)
async def __aenter__(self) -> 'AsyncTCPServer':
if ':' in self.host:
@ -196,7 +202,7 @@ async def wait_for_port(host: str, port: int) -> None:
except ConnectionRefusedError:
await asyncio.sleep(0.1)
async def waitable_range(len: int, wait: float = 0.0001) -> 'collections.abc.AsyncGenerator[int, None]':
for i in range(len):
async def waitable_range(size: int, wait: float = 0.0001) -> 'collections.abc.AsyncGenerator[int, None]':
for i in range(size):
await asyncio.sleep(wait)
yield i

View File

@ -146,7 +146,7 @@ async def create_tunnel_proc(
if response is None:
response = conf.UDS_GET_TICKET_RESPONSE(remote_host, remote_port)
port = random.randint(8000, 58000) # nosec Just a random port
port = random.randint(20000, 40000) # nosec Just a random port
hhost = f'[{listen_host}]' if ':' in listen_host else listen_host
args = {
'uds_server': f'http://{hhost}:{port}/uds/rest',
@ -205,34 +205,37 @@ async def create_tunnel_proc(
# Create the tunnel task
task = asyncio.create_task(udstunnel.tunnel_proc_async(other_end, cfg, global_stats.ns))
# Create a small asyncio server that reads the handshake,
# and sends the socket to the tunnel_proc_async using the pipe
# the pipe message will be typing.Tuple[socket.socket, typing.Tuple[str, int]]
# socket and address
async def client_connected_cb(reader, writer):
# Read the handshake
# Note: We need a small wait on sender, because this is a bufferedReader
# so it will read the handshake and the first bytes of the data (that is the ssl handshake)
_ = await reader.read(len(consts.HANDSHAKE_V1))
# For testing, we ignore the handshake value
# Send the socket to the tunnel
own_end.send(
(
writer.get_extra_info('socket').dup(),
writer.get_extra_info('peername'),
)
)
# Close the socket
writer.close()
# Server listening for connections
server_socket = socket.socket(socket.AF_INET6 if ':' in listen_host else socket.AF_INET, socket.SOCK_STREAM)
server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) # Allow reuse of address
server_socket.bind((listen_host, listen_port))
server_socket.listen(8)
server_socket.setblocking(False)
server = await asyncio.start_server(
client_connected_cb,
listen_host,
listen_port,
)
async def server():
loop = asyncio.get_running_loop()
try:
while True:
client, addr = await loop.sock_accept(server_socket)
# Send the socket to the tunnel
own_end.send((client, addr))
except asyncio.CancelledError:
pass # We are closing
except Exception:
logger.exception('Exception in server')
# Close the socket
server_socket.close()
# Create the middleware task
server_task = asyncio.create_task(server())
try:
yield cfg, possible_queue
finally:
# Cancel the middleware task
server_task.cancel()
logger.info('Server closed')
# Close the pipe (both ends)
own_end.close()
@ -240,10 +243,6 @@ async def create_tunnel_proc(
# wait for the task to finish
await task
server.close()
await server.wait_closed()
logger.info('Server closed')
# Ensure log file are removed
rootlog = logging.getLogger()
for h in rootlog.handlers:
@ -385,7 +384,7 @@ async def open_tunnel_client(
if not use_tunnel_handshake:
reader, writer = await asyncio.open_connection(
cfg.listen_address, cfg.listen_port, ssl=context, family=family
cfg.listen_address, cfg.listen_port, ssl=context, family=family, ssl_handshake_timeout=1
)
else:
# Open the socket, send handshake and then upgrade to ssl, non blocking
@ -402,7 +401,6 @@ async def open_tunnel_client(
# (reads chunks of 4096 bytes). If we don't wait, the handshake will be readed
# and part or all of ssl handshake also.
# With uvloop this seems to be not needed, but with asyncio it is.
await asyncio.sleep(0.05)
# upgrade to ssl
reader, writer = await asyncio.open_connection(
sock=sock, ssl=context, server_hostname=cfg.listen_address
@ -451,8 +449,15 @@ async def tunnel_app_runner(
finally:
# Ensure the process is terminated
if process.returncode is None:
logger.info('Terminating tunnel process %s', process.pid)
process.terminate()
await process.wait()
await asyncio.wait_for(process.wait(), 10)
# Ensure the process is terminated
if process.returncode is None:
logger.info('Killing tunnel process %s', process.pid)
process.kill()
await asyncio.wait_for(process.wait(), 10)
logger.info('Tunnel process %s terminated', process.pid)
def get_correct_ticket(length: int = consts.TICKET_LENGTH, *, prefix: typing.Optional[str] = None) -> bytes: