1
0
mirror of https://github.com/dkmstr/openuds.git synced 2025-01-18 06:03:54 +03:00

Fixes tests and some linting issues.

This commit is contained in:
Adolfo Gómez García 2023-05-21 05:35:52 +02:00
parent cb76078758
commit 379ce8a094
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
10 changed files with 42 additions and 44 deletions

View File

@ -67,7 +67,8 @@ ignored-modules=
# Python code to execute, usually for sys.path manipulation such as # Python code to execute, usually for sys.path manipulation such as
# pygtk.require(). # pygtk.require().
#init-hook= # So tests linting can find the src directory, where the code is...
init-hook='import sys, os; sys.path.append(os.path.join(os.getcwd(), "src"))'
# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the # Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
# number of processors available to use, and will cap the count on Windows to # number of processors available to use, and will cap the count on Windows to

View File

@ -1,3 +1,4 @@
psutil>=5.7.3 psutil>=5.7.3
cryptography
aiohttp aiohttp
uvloop uvloop

View File

@ -46,7 +46,6 @@ from concurrent.futures import ThreadPoolExecutor
try: try:
import uvloop import uvloop
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
except ImportError: except ImportError:
pass # no uvloop support pass # no uvloop support
@ -318,7 +317,7 @@ def tunnel_main(args: 'argparse.Namespace') -> None:
if cfg.pidfile: if cfg.pidfile:
os.unlink(cfg.pidfile) os.unlink(cfg.pidfile)
except Exception: except Exception:
pass logger.warning('Could not remove pidfile %s', cfg.pidfile)
logger.info('FINISHED') logger.info('FINISHED')

View File

@ -43,6 +43,7 @@ if typing.TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class TestUDSTunnelApp(IsolatedAsyncioTestCase): class TestUDSTunnelApp(IsolatedAsyncioTestCase):
async def client_task(self, host: str, tunnel_port: int, remote_port: int) -> None: async def client_task(self, host: str, tunnel_port: int, remote_port: int) -> None:
received: bytes = b'' received: bytes = b''
@ -50,13 +51,9 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
# Data sent will be received by server # Data sent will be received by server
# One single write will ensure all data is on same packet # One single write will ensure all data is on same packet
test_str = ( test_str = (
b'Some Random Data' b'Some Random Data' + bytes(random.randint(0, 255) for _ in range(1024)) * 4 + b'STREAM_END' # nosec: just testing data
+ bytes(random.randint(0, 255) for _ in range(1024)) * 4
+ b'STREAM_END'
) # length = 16 + 1024 * 4 + 10 = 4122 ) # length = 16 + 1024 * 4 + 10 = 4122
test_response = ( test_response = bytes(random.randint(48, 127) for _ in range(12)) # nosec: length = 12, random printable chars
bytes(random.randint(48, 127) for _ in range(12))
) # length = 12, random printable chars
def callback(data: bytes) -> typing.Optional[bytes]: def callback(data: bytes) -> typing.Optional[bytes]:
nonlocal received nonlocal received
@ -127,6 +124,7 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
fake_broker_port = 20000 fake_broker_port = 20000
tunnel_server_port = fake_broker_port + 1 tunnel_server_port = fake_broker_port + 1
remote_port = fake_broker_port + 2 remote_port = fake_broker_port + 2
# Extracts the port from an string that has bX0bwmbPORTbX0bwmb in it # Extracts the port from an string that has bX0bwmbPORTbX0bwmb in it
def extract_port(data: bytes) -> int: def extract_port(data: bytes) -> int:
if b'bX0bwmb' not in data: if b'bX0bwmb' not in data:
@ -142,9 +140,7 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
async with tuntools.create_fake_broker_server( async with tuntools.create_fake_broker_server(
host, host,
fake_broker_port, fake_broker_port,
response=lambda data: conf.UDS_GET_TICKET_RESPONSE( response=lambda data: conf.UDS_GET_TICKET_RESPONSE(host, extract_port(data)),
host, extract_port(data)
),
) as req_queue: ) as req_queue:
if req_queue is None: if req_queue is None:
raise AssertionError('req_queue is None') raise AssertionError('req_queue is None')
@ -160,12 +156,9 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
workers=4, workers=4,
command_timeout=16, # Increase command timeout because heavy load we will create command_timeout=16, # Increase command timeout because heavy load we will create
) as process: ) as process:
# Create a "bunch" of clients # Create a "bunch" of clients
tasks = [ tasks = [
asyncio.create_task( asyncio.create_task(self.client_task(host, tunnel_server_port, remote_port + i))
self.client_task(host, tunnel_server_port, remote_port + i)
)
async for i in tools.waitable_range(concurrent_tasks) async for i in tools.waitable_range(concurrent_tasks)
] ]
@ -185,7 +178,7 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
remote_port = fake_broker_port + 2 remote_port = fake_broker_port + 2
# Extracts the port from an string that has bX0bwmbPORTbX0bwmb in it # Extracts the port from an string that has bX0bwmbPORTbX0bwmb in it
req_queue: asyncio.Queue[bytes] = asyncio.Queue() req_queue: 'asyncio.Queue[bytes]' = asyncio.Queue()
def extract_port(data: bytes) -> int: def extract_port(data: bytes) -> int:
logger.debug('Data: %r', data) logger.debug('Data: %r', data)
@ -206,17 +199,13 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
async with tuntools.create_tunnel_proc( async with tuntools.create_tunnel_proc(
host, host,
tunnel_server_port, tunnel_server_port,
response=lambda data: conf.UDS_GET_TICKET_RESPONSE( response=lambda data: conf.UDS_GET_TICKET_RESPONSE(host, extract_port(data)),
host, extract_port(data)
),
command_timeout=16, # Increase command timeout because heavy load we will create, command_timeout=16, # Increase command timeout because heavy load we will create,
global_stats=stats_collector, global_stats=stats_collector,
) as (cfg, _): ) as (cfg, _):
# Create a "bunch" of clients # Create a "bunch" of clients
tasks = [ tasks = [
asyncio.create_task( asyncio.create_task(self.client_task(host, tunnel_server_port, remote_port + i))
self.client_task(host, tunnel_server_port, remote_port + i)
)
async for i in tools.waitable_range(concurrent_tasks) async for i in tools.waitable_range(concurrent_tasks)
] ]
@ -228,8 +217,8 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
# Queue should have all requests (concurrent_tasks*2, one for open and one for close) # Queue should have all requests (concurrent_tasks*2, one for open and one for close)
self.assertEqual(req_queue.qsize(), concurrent_tasks * 2) self.assertEqual(req_queue.qsize(), concurrent_tasks * 2)
# Check stats # Check stats
self.assertEqual(stats_collector.ns.recv, concurrent_tasks*12) self.assertEqual(stats_collector.ns.recv, concurrent_tasks * 12)
self.assertEqual(stats_collector.ns.sent, concurrent_tasks*4122) self.assertEqual(stats_collector.ns.sent, concurrent_tasks * 4122)
self.assertEqual(stats_collector.ns.total, concurrent_tasks) self.assertEqual(stats_collector.ns.total, concurrent_tasks)

View File

@ -49,7 +49,7 @@ class TestConfigFile(TestCase):
values['logsize'] = values['logsize'] * 1024 * 1024 values['logsize'] = values['logsize'] * 1024 * 1024
values['listen_address'] = values['address'] values['listen_address'] = values['address']
values['listen_port'] = values['port'] values['listen_port'] = values['port']
del values['address'] del values['address']
del values['port'] del values['port']
# Ensure data is correct # Ensure data is correct

View File

@ -28,7 +28,6 @@
''' '''
Author: Adolfo Gómez, dkmaster at dkmon dot com Author: Adolfo Gómez, dkmaster at dkmon dot com
''' '''
import typing
import random import random
import socket import socket
import logging import logging
@ -36,7 +35,7 @@ import multiprocessing
from unittest import IsolatedAsyncioTestCase, mock from unittest import IsolatedAsyncioTestCase, mock
from udstunnel import process_connection from udstunnel import process_connection
from uds_tunnel import tunnel, consts from uds_tunnel import consts
from .utils import tuntools from .utils import tuntools
@ -64,7 +63,7 @@ class TestTunnel(IsolatedAsyncioTestCase):
for i in range(0, 100, 10): for i in range(0, 100, 10):
# Set timeout to 1 seconds # Set timeout to 1 seconds
bad_cmd = bytes(random.randint(0, 255) for _ in range(i)) # Some garbage bad_cmd = bytes(random.randint(0, 255) for _ in range(i)) # Some garbage
logger.info(f'Testing invalid command with {bad_cmd!r}') logger.info('Testing invalid command with %s', bad_cmd)
async with tuntools.create_test_tunnel(callback=lambda x: None, port=7770, remote_port=54555, command_timeout=0.1) as cfg: async with tuntools.create_test_tunnel(callback=lambda x: None, port=7770, remote_port=54555, command_timeout=0.1) as cfg:
logger_mock = mock.MagicMock() logger_mock = mock.MagicMock()
with mock.patch('uds_tunnel.tunnel.logger', logger_mock): with mock.patch('uds_tunnel.tunnel.logger', logger_mock):

View File

@ -28,11 +28,10 @@
''' '''
Author: Adolfo Gómez, dkmaster at dkmon dot com Author: Adolfo Gómez, dkmaster at dkmon dot com
''' '''
import typing
import random import random
import asyncio import asyncio
import logging import logging
from unittest import IsolatedAsyncioTestCase, mock from unittest import IsolatedAsyncioTestCase
from uds_tunnel import consts from uds_tunnel import consts

View File

@ -107,7 +107,7 @@ def sslContext(ip: str) -> typing.Tuple[ssl.SSLContext, str, str]:
f.write(cert) f.write(cert)
# Create SSL context # Create SSL context
ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
ssl_ctx.options |= ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 ssl_ctx.minimum_version = ssl.TLSVersion.TLSv1_2
ssl_ctx.load_cert_chain(certfile=f'{tmpdir}/{tmpname}.pem', password=password) ssl_ctx.load_cert_chain(certfile=f'{tmpdir}/{tmpname}.pem', password=password)
ssl_ctx.check_hostname = False ssl_ctx.check_hostname = False
ssl_ctx.set_ciphers('ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384') ssl_ctx.set_ciphers('ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384')

View File

@ -32,6 +32,7 @@ import asyncio
import os import os
import ssl import ssl
import typing import typing
import collections.abc
import socket import socket
import aiohttp import aiohttp
import logging import logging
@ -149,7 +150,7 @@ class AsyncTCPServer:
await self._processor(reader, writer) await self._processor(reader, writer)
return return
while True: while True:
data = await reader.read(4096) data = await reader.read(4096) # Care with this and tunnel handshake on testings...
if not data: if not data:
break break
@ -195,7 +196,7 @@ async def wait_for_port(host: str, port: int) -> None:
except ConnectionRefusedError: except ConnectionRefusedError:
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
async def waitable_range(len: int, wait: float = 0.0001) -> typing.AsyncGenerator[int, None]: async def waitable_range(len: int, wait: float = 0.0001) -> 'collections.abc.AsyncGenerator[int, None]':
for i in range(len): for i in range(len):
await asyncio.sleep(wait) await asyncio.sleep(wait)
yield i yield i

View File

@ -30,6 +30,7 @@ Author: Adolfo Gómez, dkmaster at dkmon dot com
''' '''
import asyncio import asyncio
import contextlib import contextlib
import collections.abc
import json import json
import logging import logging
import multiprocessing import multiprocessing
@ -122,7 +123,7 @@ async def create_tunnel_proc(
global_stats: typing.Optional[stats.GlobalStats] = None, global_stats: typing.Optional[stats.GlobalStats] = None,
# Configuration parameters # Configuration parameters
**kwargs, **kwargs,
) -> typing.AsyncGenerator[ ) -> collections.abc.AsyncGenerator[
typing.Tuple['config.ConfigurationType', typing.Optional[asyncio.Queue[bytes]]], typing.Tuple['config.ConfigurationType', typing.Optional[asyncio.Queue[bytes]]],
None, None,
]: ]:
@ -138,7 +139,7 @@ async def create_tunnel_proc(
use_fake_http_server (bool, optional): If True, a fake http server will be used instead of a mock. Defaults to False. use_fake_http_server (bool, optional): If True, a fake http server will be used instead of a mock. Defaults to False.
Yields: Yields:
typing.AsyncGenerator[typing.Tuple[config.ConfigurationType, typing.Optional[asyncio.Queue[bytes]]], None]: A tuple with the configuration collections.abc.AsyncGenerator[typing.Tuple[config.ConfigurationType, typing.Optional[asyncio.Queue[bytes]]], None]: A tuple with the configuration
and a queue with the data received by the "fake_http_server" if used, or None if not used and a queue with the data received by the "fake_http_server" if used, or None if not used
""" """
@ -158,7 +159,7 @@ async def create_tunnel_proc(
resp = conf.UDS_GET_TICKET_RESPONSE(remote_host, remote_port) resp = conf.UDS_GET_TICKET_RESPONSE(remote_host, remote_port)
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def provider() -> typing.AsyncGenerator[ async def provider() -> collections.abc.AsyncGenerator[
typing.Optional[asyncio.Queue[bytes]], None typing.Optional[asyncio.Queue[bytes]], None
]: ]:
async with create_fake_broker_server( async with create_fake_broker_server(
@ -172,7 +173,7 @@ async def create_tunnel_proc(
else: else:
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def provider() -> typing.AsyncGenerator[ async def provider() -> collections.abc.AsyncGenerator[
typing.Optional[asyncio.Queue[bytes]], None typing.Optional[asyncio.Queue[bytes]], None
]: ]:
with mock.patch( with mock.patch(
@ -219,7 +220,9 @@ async def create_tunnel_proc(
# socket and address # socket and address
async def client_connected_cb(reader, writer): async def client_connected_cb(reader, writer):
# Read the handshake # Read the handshake
data = await reader.read(1024) # Note: We need a small wait on sender, because this is a bufferedReader
# so it will read the handshake and the first bytes of the data (that is the ssl handshake)
_ = await reader.read(len(consts.HANDSHAKE_V1))
# For testing, we ignore the handshake value # For testing, we ignore the handshake value
# Send the socket to the tunnel # Send the socket to the tunnel
own_end.send( own_end.send(
@ -297,7 +300,7 @@ async def create_test_tunnel(
remote_port: typing.Optional[int] = None, remote_port: typing.Optional[int] = None,
# Configuration parameters # Configuration parameters
**kwargs: typing.Any, **kwargs: typing.Any,
) -> typing.AsyncGenerator['config.ConfigurationType', None]: ) -> collections.abc.AsyncGenerator['config.ConfigurationType', None]:
# Generate a listening server for testing tunnel # Generate a listening server for testing tunnel
# Prepare the end of the tunnel # Prepare the end of the tunnel
async with tools.AsyncTCPServer( async with tools.AsyncTCPServer(
@ -337,7 +340,7 @@ async def create_fake_broker_server(
typing.Mapping[str, typing.Any], typing.Mapping[str, typing.Any],
] ]
], ],
) -> typing.AsyncGenerator[asyncio.Queue[bytes], None]: ) -> collections.abc.AsyncGenerator[asyncio.Queue[bytes], None]:
# crate a fake broker server # crate a fake broker server
# Ignores request, and sends response # Ignores request, and sends response
# if is a callable, it will be called to get the response and encode it as json # if is a callable, it will be called to get the response and encode it as json
@ -388,7 +391,7 @@ async def open_tunnel_client(
cfg: 'config.ConfigurationType', cfg: 'config.ConfigurationType',
use_tunnel_handshake: bool = False, use_tunnel_handshake: bool = False,
local_port: typing.Optional[int] = None, local_port: typing.Optional[int] = None,
) -> typing.AsyncGenerator[ ) -> collections.abc.AsyncGenerator[
typing.Tuple[asyncio.StreamReader, asyncio.StreamWriter], None typing.Tuple[asyncio.StreamReader, asyncio.StreamWriter], None
]: ]:
"""opens an ssl socket to the tunnel server""" """opens an ssl socket to the tunnel server"""
@ -414,6 +417,12 @@ async def open_tunnel_client(
sock.setblocking(False) sock.setblocking(False)
await loop.sock_connect(sock, (cfg.listen_address, cfg.listen_port)) await loop.sock_connect(sock, (cfg.listen_address, cfg.listen_port))
await loop.sock_sendall(sock, consts.HANDSHAKE_V1) await loop.sock_sendall(sock, consts.HANDSHAKE_V1)
# Note, we need an small delay, because the "middle connection", in case of tunnel proc,
# that will simulate the tunnel handshake processor is running over a bufferedReader
# (reads chunks of 4096 bytes). If we don't wait, the handshake will be readed
# and part or all of ssl handshake also.
# With uvloop this seems to be not needed, but with asyncio it is.
await asyncio.sleep(0.05)
# upgrade to ssl # upgrade to ssl
reader, writer = await asyncio.open_connection( reader, writer = await asyncio.open_connection(
sock=sock, ssl=context, server_hostname=cfg.listen_address sock=sock, ssl=context, server_hostname=cfg.listen_address
@ -433,7 +442,7 @@ async def tunnel_app_runner(
wait_for_port: bool = False, wait_for_port: bool = False,
args: typing.Optional[typing.List[str]] = None, args: typing.Optional[typing.List[str]] = None,
**kwargs: typing.Union[str, int, bool], **kwargs: typing.Union[str, int, bool],
) -> typing.AsyncGenerator['Process', None]: ) -> collections.abc.AsyncGenerator['Process', None]:
# Ensure we are on src directory # Ensure we are on src directory
if os.path.basename(os.getcwd()) != 'src': if os.path.basename(os.getcwd()) != 'src':
os.chdir('src') os.chdir('src')