1
0
mirror of https://github.com/dkmstr/openuds.git synced 2025-01-18 06:03:54 +03:00

Fixes tests and some linting issues.

This commit is contained in:
Adolfo Gómez García 2023-05-21 05:35:52 +02:00
parent cb76078758
commit 379ce8a094
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
10 changed files with 42 additions and 44 deletions

View File

@ -67,7 +67,8 @@ ignored-modules=
# Python code to execute, usually for sys.path manipulation such as
# pygtk.require().
#init-hook=
# So tests linting can find the src directory, where the code is...
init-hook='import sys, os; sys.path.append(os.path.join(os.getcwd(), "src"))'
# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
# number of processors available to use, and will cap the count on Windows to

View File

@ -1,3 +1,4 @@
psutil>=5.7.3
cryptography
aiohttp
uvloop

View File

@ -46,7 +46,6 @@ from concurrent.futures import ThreadPoolExecutor
try:
import uvloop
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
except ImportError:
pass # no uvloop support
@ -318,7 +317,7 @@ def tunnel_main(args: 'argparse.Namespace') -> None:
if cfg.pidfile:
os.unlink(cfg.pidfile)
except Exception:
pass
logger.warning('Could not remove pidfile %s', cfg.pidfile)
logger.info('FINISHED')

View File

@ -43,6 +43,7 @@ if typing.TYPE_CHECKING:
logger = logging.getLogger(__name__)
class TestUDSTunnelApp(IsolatedAsyncioTestCase):
async def client_task(self, host: str, tunnel_port: int, remote_port: int) -> None:
received: bytes = b''
@ -50,13 +51,9 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
# Data sent will be received by server
# One single write will ensure all data is on same packet
test_str = (
b'Some Random Data'
+ bytes(random.randint(0, 255) for _ in range(1024)) * 4
+ b'STREAM_END'
b'Some Random Data' + bytes(random.randint(0, 255) for _ in range(1024)) * 4 + b'STREAM_END' # nosec: just testing data
) # length = 16 + 1024 * 4 + 10 = 4122
test_response = (
bytes(random.randint(48, 127) for _ in range(12))
) # length = 12, random printable chars
test_response = bytes(random.randint(48, 127) for _ in range(12)) # nosec: length = 12, random printable chars
def callback(data: bytes) -> typing.Optional[bytes]:
nonlocal received
@ -127,6 +124,7 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
fake_broker_port = 20000
tunnel_server_port = fake_broker_port + 1
remote_port = fake_broker_port + 2
# Extracts the port from an string that has bX0bwmbPORTbX0bwmb in it
def extract_port(data: bytes) -> int:
if b'bX0bwmb' not in data:
@ -142,9 +140,7 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
async with tuntools.create_fake_broker_server(
host,
fake_broker_port,
response=lambda data: conf.UDS_GET_TICKET_RESPONSE(
host, extract_port(data)
),
response=lambda data: conf.UDS_GET_TICKET_RESPONSE(host, extract_port(data)),
) as req_queue:
if req_queue is None:
raise AssertionError('req_queue is None')
@ -160,12 +156,9 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
workers=4,
command_timeout=16, # Increase command timeout because heavy load we will create
) as process:
# Create a "bunch" of clients
tasks = [
asyncio.create_task(
self.client_task(host, tunnel_server_port, remote_port + i)
)
asyncio.create_task(self.client_task(host, tunnel_server_port, remote_port + i))
async for i in tools.waitable_range(concurrent_tasks)
]
@ -185,7 +178,7 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
remote_port = fake_broker_port + 2
# Extracts the port from an string that has bX0bwmbPORTbX0bwmb in it
req_queue: asyncio.Queue[bytes] = asyncio.Queue()
req_queue: 'asyncio.Queue[bytes]' = asyncio.Queue()
def extract_port(data: bytes) -> int:
logger.debug('Data: %r', data)
@ -206,17 +199,13 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
async with tuntools.create_tunnel_proc(
host,
tunnel_server_port,
response=lambda data: conf.UDS_GET_TICKET_RESPONSE(
host, extract_port(data)
),
response=lambda data: conf.UDS_GET_TICKET_RESPONSE(host, extract_port(data)),
command_timeout=16, # Increase command timeout because heavy load we will create,
global_stats=stats_collector,
) as (cfg, _):
# Create a "bunch" of clients
tasks = [
asyncio.create_task(
self.client_task(host, tunnel_server_port, remote_port + i)
)
asyncio.create_task(self.client_task(host, tunnel_server_port, remote_port + i))
async for i in tools.waitable_range(concurrent_tasks)
]
@ -228,8 +217,8 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
# Queue should have all requests (concurrent_tasks*2, one for open and one for close)
self.assertEqual(req_queue.qsize(), concurrent_tasks * 2)
# Check stats
self.assertEqual(stats_collector.ns.recv, concurrent_tasks*12)
self.assertEqual(stats_collector.ns.sent, concurrent_tasks*4122)
self.assertEqual(stats_collector.ns.recv, concurrent_tasks * 12)
self.assertEqual(stats_collector.ns.sent, concurrent_tasks * 4122)
self.assertEqual(stats_collector.ns.total, concurrent_tasks)

View File

@ -49,7 +49,7 @@ class TestConfigFile(TestCase):
values['logsize'] = values['logsize'] * 1024 * 1024
values['listen_address'] = values['address']
values['listen_port'] = values['port']
del values['address']
del values['port']
# Ensure data is correct

View File

@ -28,7 +28,6 @@
'''
Author: Adolfo Gómez, dkmaster at dkmon dot com
'''
import typing
import random
import socket
import logging
@ -36,7 +35,7 @@ import multiprocessing
from unittest import IsolatedAsyncioTestCase, mock
from udstunnel import process_connection
from uds_tunnel import tunnel, consts
from uds_tunnel import consts
from .utils import tuntools
@ -64,7 +63,7 @@ class TestTunnel(IsolatedAsyncioTestCase):
for i in range(0, 100, 10):
# Set timeout to 1 seconds
bad_cmd = bytes(random.randint(0, 255) for _ in range(i)) # Some garbage
logger.info(f'Testing invalid command with {bad_cmd!r}')
logger.info('Testing invalid command with %s', bad_cmd)
async with tuntools.create_test_tunnel(callback=lambda x: None, port=7770, remote_port=54555, command_timeout=0.1) as cfg:
logger_mock = mock.MagicMock()
with mock.patch('uds_tunnel.tunnel.logger', logger_mock):

View File

@ -28,11 +28,10 @@
'''
Author: Adolfo Gómez, dkmaster at dkmon dot com
'''
import typing
import random
import asyncio
import logging
from unittest import IsolatedAsyncioTestCase, mock
from unittest import IsolatedAsyncioTestCase
from uds_tunnel import consts

View File

@ -107,7 +107,7 @@ def sslContext(ip: str) -> typing.Tuple[ssl.SSLContext, str, str]:
f.write(cert)
# Create SSL context
ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
ssl_ctx.options |= ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1
ssl_ctx.minimum_version = ssl.TLSVersion.TLSv1_2
ssl_ctx.load_cert_chain(certfile=f'{tmpdir}/{tmpname}.pem', password=password)
ssl_ctx.check_hostname = False
ssl_ctx.set_ciphers('ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384')

View File

@ -32,6 +32,7 @@ import asyncio
import os
import ssl
import typing
import collections.abc
import socket
import aiohttp
import logging
@ -149,7 +150,7 @@ class AsyncTCPServer:
await self._processor(reader, writer)
return
while True:
data = await reader.read(4096)
data = await reader.read(4096) # Care with this and tunnel handshake on testings...
if not data:
break
@ -195,7 +196,7 @@ async def wait_for_port(host: str, port: int) -> None:
except ConnectionRefusedError:
await asyncio.sleep(0.1)
async def waitable_range(len: int, wait: float = 0.0001) -> typing.AsyncGenerator[int, None]:
async def waitable_range(len: int, wait: float = 0.0001) -> 'collections.abc.AsyncGenerator[int, None]':
for i in range(len):
await asyncio.sleep(wait)
yield i

View File

@ -30,6 +30,7 @@ Author: Adolfo Gómez, dkmaster at dkmon dot com
'''
import asyncio
import contextlib
import collections.abc
import json
import logging
import multiprocessing
@ -122,7 +123,7 @@ async def create_tunnel_proc(
global_stats: typing.Optional[stats.GlobalStats] = None,
# Configuration parameters
**kwargs,
) -> typing.AsyncGenerator[
) -> collections.abc.AsyncGenerator[
typing.Tuple['config.ConfigurationType', typing.Optional[asyncio.Queue[bytes]]],
None,
]:
@ -138,7 +139,7 @@ async def create_tunnel_proc(
use_fake_http_server (bool, optional): If True, a fake http server will be used instead of a mock. Defaults to False.
Yields:
typing.AsyncGenerator[typing.Tuple[config.ConfigurationType, typing.Optional[asyncio.Queue[bytes]]], None]: A tuple with the configuration
collections.abc.AsyncGenerator[typing.Tuple[config.ConfigurationType, typing.Optional[asyncio.Queue[bytes]]], None]: A tuple with the configuration
and a queue with the data received by the "fake_http_server" if used, or None if not used
"""
@ -158,7 +159,7 @@ async def create_tunnel_proc(
resp = conf.UDS_GET_TICKET_RESPONSE(remote_host, remote_port)
@contextlib.asynccontextmanager
async def provider() -> typing.AsyncGenerator[
async def provider() -> collections.abc.AsyncGenerator[
typing.Optional[asyncio.Queue[bytes]], None
]:
async with create_fake_broker_server(
@ -172,7 +173,7 @@ async def create_tunnel_proc(
else:
@contextlib.asynccontextmanager
async def provider() -> typing.AsyncGenerator[
async def provider() -> collections.abc.AsyncGenerator[
typing.Optional[asyncio.Queue[bytes]], None
]:
with mock.patch(
@ -219,7 +220,9 @@ async def create_tunnel_proc(
# socket and address
async def client_connected_cb(reader, writer):
# Read the handshake
data = await reader.read(1024)
# Note: We need a small wait on sender, because this is a bufferedReader
# so it will read the handshake and the first bytes of the data (that is the ssl handshake)
_ = await reader.read(len(consts.HANDSHAKE_V1))
# For testing, we ignore the handshake value
# Send the socket to the tunnel
own_end.send(
@ -297,7 +300,7 @@ async def create_test_tunnel(
remote_port: typing.Optional[int] = None,
# Configuration parameters
**kwargs: typing.Any,
) -> typing.AsyncGenerator['config.ConfigurationType', None]:
) -> collections.abc.AsyncGenerator['config.ConfigurationType', None]:
# Generate a listening server for testing tunnel
# Prepare the end of the tunnel
async with tools.AsyncTCPServer(
@ -337,7 +340,7 @@ async def create_fake_broker_server(
typing.Mapping[str, typing.Any],
]
],
) -> typing.AsyncGenerator[asyncio.Queue[bytes], None]:
) -> collections.abc.AsyncGenerator[asyncio.Queue[bytes], None]:
# crate a fake broker server
# Ignores request, and sends response
# if is a callable, it will be called to get the response and encode it as json
@ -388,7 +391,7 @@ async def open_tunnel_client(
cfg: 'config.ConfigurationType',
use_tunnel_handshake: bool = False,
local_port: typing.Optional[int] = None,
) -> typing.AsyncGenerator[
) -> collections.abc.AsyncGenerator[
typing.Tuple[asyncio.StreamReader, asyncio.StreamWriter], None
]:
"""opens an ssl socket to the tunnel server"""
@ -414,6 +417,12 @@ async def open_tunnel_client(
sock.setblocking(False)
await loop.sock_connect(sock, (cfg.listen_address, cfg.listen_port))
await loop.sock_sendall(sock, consts.HANDSHAKE_V1)
# Note, we need an small delay, because the "middle connection", in case of tunnel proc,
# that will simulate the tunnel handshake processor is running over a bufferedReader
# (reads chunks of 4096 bytes). If we don't wait, the handshake will be readed
# and part or all of ssl handshake also.
# With uvloop this seems to be not needed, but with asyncio it is.
await asyncio.sleep(0.05)
# upgrade to ssl
reader, writer = await asyncio.open_connection(
sock=sock, ssl=context, server_hostname=cfg.listen_address
@ -433,7 +442,7 @@ async def tunnel_app_runner(
wait_for_port: bool = False,
args: typing.Optional[typing.List[str]] = None,
**kwargs: typing.Union[str, int, bool],
) -> typing.AsyncGenerator['Process', None]:
) -> collections.abc.AsyncGenerator['Process', None]:
# Ensure we are on src directory
if os.path.basename(os.getcwd()) != 'src':
os.chdir('src')