mirror of
https://github.com/dkmstr/openuds.git
synced 2025-01-18 06:03:54 +03:00
Fixes tests and some linting issues.
This commit is contained in:
parent
cb76078758
commit
379ce8a094
@ -67,7 +67,8 @@ ignored-modules=
|
||||
|
||||
# Python code to execute, usually for sys.path manipulation such as
|
||||
# pygtk.require().
|
||||
#init-hook=
|
||||
# So tests linting can find the src directory, where the code is...
|
||||
init-hook='import sys, os; sys.path.append(os.path.join(os.getcwd(), "src"))'
|
||||
|
||||
# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
|
||||
# number of processors available to use, and will cap the count on Windows to
|
||||
|
@ -1,3 +1,4 @@
|
||||
psutil>=5.7.3
|
||||
cryptography
|
||||
aiohttp
|
||||
uvloop
|
||||
|
@ -46,7 +46,6 @@ from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
try:
|
||||
import uvloop
|
||||
|
||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||
except ImportError:
|
||||
pass # no uvloop support
|
||||
@ -318,7 +317,7 @@ def tunnel_main(args: 'argparse.Namespace') -> None:
|
||||
if cfg.pidfile:
|
||||
os.unlink(cfg.pidfile)
|
||||
except Exception:
|
||||
pass
|
||||
logger.warning('Could not remove pidfile %s', cfg.pidfile)
|
||||
|
||||
logger.info('FINISHED')
|
||||
|
||||
|
@ -43,6 +43,7 @@ if typing.TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TestUDSTunnelApp(IsolatedAsyncioTestCase):
|
||||
async def client_task(self, host: str, tunnel_port: int, remote_port: int) -> None:
|
||||
received: bytes = b''
|
||||
@ -50,13 +51,9 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
|
||||
# Data sent will be received by server
|
||||
# One single write will ensure all data is on same packet
|
||||
test_str = (
|
||||
b'Some Random Data'
|
||||
+ bytes(random.randint(0, 255) for _ in range(1024)) * 4
|
||||
+ b'STREAM_END'
|
||||
b'Some Random Data' + bytes(random.randint(0, 255) for _ in range(1024)) * 4 + b'STREAM_END' # nosec: just testing data
|
||||
) # length = 16 + 1024 * 4 + 10 = 4122
|
||||
test_response = (
|
||||
bytes(random.randint(48, 127) for _ in range(12))
|
||||
) # length = 12, random printable chars
|
||||
test_response = bytes(random.randint(48, 127) for _ in range(12)) # nosec: length = 12, random printable chars
|
||||
|
||||
def callback(data: bytes) -> typing.Optional[bytes]:
|
||||
nonlocal received
|
||||
@ -127,6 +124,7 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
|
||||
fake_broker_port = 20000
|
||||
tunnel_server_port = fake_broker_port + 1
|
||||
remote_port = fake_broker_port + 2
|
||||
|
||||
# Extracts the port from an string that has bX0bwmbPORTbX0bwmb in it
|
||||
def extract_port(data: bytes) -> int:
|
||||
if b'bX0bwmb' not in data:
|
||||
@ -142,9 +140,7 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
|
||||
async with tuntools.create_fake_broker_server(
|
||||
host,
|
||||
fake_broker_port,
|
||||
response=lambda data: conf.UDS_GET_TICKET_RESPONSE(
|
||||
host, extract_port(data)
|
||||
),
|
||||
response=lambda data: conf.UDS_GET_TICKET_RESPONSE(host, extract_port(data)),
|
||||
) as req_queue:
|
||||
if req_queue is None:
|
||||
raise AssertionError('req_queue is None')
|
||||
@ -160,12 +156,9 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
|
||||
workers=4,
|
||||
command_timeout=16, # Increase command timeout because heavy load we will create
|
||||
) as process:
|
||||
|
||||
# Create a "bunch" of clients
|
||||
tasks = [
|
||||
asyncio.create_task(
|
||||
self.client_task(host, tunnel_server_port, remote_port + i)
|
||||
)
|
||||
asyncio.create_task(self.client_task(host, tunnel_server_port, remote_port + i))
|
||||
async for i in tools.waitable_range(concurrent_tasks)
|
||||
]
|
||||
|
||||
@ -185,7 +178,7 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
|
||||
remote_port = fake_broker_port + 2
|
||||
# Extracts the port from an string that has bX0bwmbPORTbX0bwmb in it
|
||||
|
||||
req_queue: asyncio.Queue[bytes] = asyncio.Queue()
|
||||
req_queue: 'asyncio.Queue[bytes]' = asyncio.Queue()
|
||||
|
||||
def extract_port(data: bytes) -> int:
|
||||
logger.debug('Data: %r', data)
|
||||
@ -206,17 +199,13 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
|
||||
async with tuntools.create_tunnel_proc(
|
||||
host,
|
||||
tunnel_server_port,
|
||||
response=lambda data: conf.UDS_GET_TICKET_RESPONSE(
|
||||
host, extract_port(data)
|
||||
),
|
||||
response=lambda data: conf.UDS_GET_TICKET_RESPONSE(host, extract_port(data)),
|
||||
command_timeout=16, # Increase command timeout because heavy load we will create,
|
||||
global_stats=stats_collector,
|
||||
) as (cfg, _):
|
||||
# Create a "bunch" of clients
|
||||
tasks = [
|
||||
asyncio.create_task(
|
||||
self.client_task(host, tunnel_server_port, remote_port + i)
|
||||
)
|
||||
asyncio.create_task(self.client_task(host, tunnel_server_port, remote_port + i))
|
||||
async for i in tools.waitable_range(concurrent_tasks)
|
||||
]
|
||||
|
||||
@ -230,6 +219,6 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
|
||||
self.assertEqual(req_queue.qsize(), concurrent_tasks * 2)
|
||||
|
||||
# Check stats
|
||||
self.assertEqual(stats_collector.ns.recv, concurrent_tasks*12)
|
||||
self.assertEqual(stats_collector.ns.sent, concurrent_tasks*4122)
|
||||
self.assertEqual(stats_collector.ns.recv, concurrent_tasks * 12)
|
||||
self.assertEqual(stats_collector.ns.sent, concurrent_tasks * 4122)
|
||||
self.assertEqual(stats_collector.ns.total, concurrent_tasks)
|
||||
|
@ -28,7 +28,6 @@
|
||||
'''
|
||||
Author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
'''
|
||||
import typing
|
||||
import random
|
||||
import socket
|
||||
import logging
|
||||
@ -36,7 +35,7 @@ import multiprocessing
|
||||
from unittest import IsolatedAsyncioTestCase, mock
|
||||
|
||||
from udstunnel import process_connection
|
||||
from uds_tunnel import tunnel, consts
|
||||
from uds_tunnel import consts
|
||||
|
||||
from .utils import tuntools
|
||||
|
||||
@ -64,7 +63,7 @@ class TestTunnel(IsolatedAsyncioTestCase):
|
||||
for i in range(0, 100, 10):
|
||||
# Set timeout to 1 seconds
|
||||
bad_cmd = bytes(random.randint(0, 255) for _ in range(i)) # Some garbage
|
||||
logger.info(f'Testing invalid command with {bad_cmd!r}')
|
||||
logger.info('Testing invalid command with %s', bad_cmd)
|
||||
async with tuntools.create_test_tunnel(callback=lambda x: None, port=7770, remote_port=54555, command_timeout=0.1) as cfg:
|
||||
logger_mock = mock.MagicMock()
|
||||
with mock.patch('uds_tunnel.tunnel.logger', logger_mock):
|
||||
|
@ -28,11 +28,10 @@
|
||||
'''
|
||||
Author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
'''
|
||||
import typing
|
||||
import random
|
||||
import asyncio
|
||||
import logging
|
||||
from unittest import IsolatedAsyncioTestCase, mock
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
from uds_tunnel import consts
|
||||
|
||||
|
@ -107,7 +107,7 @@ def sslContext(ip: str) -> typing.Tuple[ssl.SSLContext, str, str]:
|
||||
f.write(cert)
|
||||
# Create SSL context
|
||||
ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||
ssl_ctx.options |= ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1
|
||||
ssl_ctx.minimum_version = ssl.TLSVersion.TLSv1_2
|
||||
ssl_ctx.load_cert_chain(certfile=f'{tmpdir}/{tmpname}.pem', password=password)
|
||||
ssl_ctx.check_hostname = False
|
||||
ssl_ctx.set_ciphers('ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384')
|
||||
|
@ -32,6 +32,7 @@ import asyncio
|
||||
import os
|
||||
import ssl
|
||||
import typing
|
||||
import collections.abc
|
||||
import socket
|
||||
import aiohttp
|
||||
import logging
|
||||
@ -149,7 +150,7 @@ class AsyncTCPServer:
|
||||
await self._processor(reader, writer)
|
||||
return
|
||||
while True:
|
||||
data = await reader.read(4096)
|
||||
data = await reader.read(4096) # Care with this and tunnel handshake on testings...
|
||||
if not data:
|
||||
break
|
||||
|
||||
@ -195,7 +196,7 @@ async def wait_for_port(host: str, port: int) -> None:
|
||||
except ConnectionRefusedError:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
async def waitable_range(len: int, wait: float = 0.0001) -> typing.AsyncGenerator[int, None]:
|
||||
async def waitable_range(len: int, wait: float = 0.0001) -> 'collections.abc.AsyncGenerator[int, None]':
|
||||
for i in range(len):
|
||||
await asyncio.sleep(wait)
|
||||
yield i
|
||||
|
@ -30,6 +30,7 @@ Author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
'''
|
||||
import asyncio
|
||||
import contextlib
|
||||
import collections.abc
|
||||
import json
|
||||
import logging
|
||||
import multiprocessing
|
||||
@ -122,7 +123,7 @@ async def create_tunnel_proc(
|
||||
global_stats: typing.Optional[stats.GlobalStats] = None,
|
||||
# Configuration parameters
|
||||
**kwargs,
|
||||
) -> typing.AsyncGenerator[
|
||||
) -> collections.abc.AsyncGenerator[
|
||||
typing.Tuple['config.ConfigurationType', typing.Optional[asyncio.Queue[bytes]]],
|
||||
None,
|
||||
]:
|
||||
@ -138,7 +139,7 @@ async def create_tunnel_proc(
|
||||
use_fake_http_server (bool, optional): If True, a fake http server will be used instead of a mock. Defaults to False.
|
||||
|
||||
Yields:
|
||||
typing.AsyncGenerator[typing.Tuple[config.ConfigurationType, typing.Optional[asyncio.Queue[bytes]]], None]: A tuple with the configuration
|
||||
collections.abc.AsyncGenerator[typing.Tuple[config.ConfigurationType, typing.Optional[asyncio.Queue[bytes]]], None]: A tuple with the configuration
|
||||
and a queue with the data received by the "fake_http_server" if used, or None if not used
|
||||
"""
|
||||
|
||||
@ -158,7 +159,7 @@ async def create_tunnel_proc(
|
||||
resp = conf.UDS_GET_TICKET_RESPONSE(remote_host, remote_port)
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def provider() -> typing.AsyncGenerator[
|
||||
async def provider() -> collections.abc.AsyncGenerator[
|
||||
typing.Optional[asyncio.Queue[bytes]], None
|
||||
]:
|
||||
async with create_fake_broker_server(
|
||||
@ -172,7 +173,7 @@ async def create_tunnel_proc(
|
||||
else:
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def provider() -> typing.AsyncGenerator[
|
||||
async def provider() -> collections.abc.AsyncGenerator[
|
||||
typing.Optional[asyncio.Queue[bytes]], None
|
||||
]:
|
||||
with mock.patch(
|
||||
@ -219,7 +220,9 @@ async def create_tunnel_proc(
|
||||
# socket and address
|
||||
async def client_connected_cb(reader, writer):
|
||||
# Read the handshake
|
||||
data = await reader.read(1024)
|
||||
# Note: We need a small wait on sender, because this is a bufferedReader
|
||||
# so it will read the handshake and the first bytes of the data (that is the ssl handshake)
|
||||
_ = await reader.read(len(consts.HANDSHAKE_V1))
|
||||
# For testing, we ignore the handshake value
|
||||
# Send the socket to the tunnel
|
||||
own_end.send(
|
||||
@ -297,7 +300,7 @@ async def create_test_tunnel(
|
||||
remote_port: typing.Optional[int] = None,
|
||||
# Configuration parameters
|
||||
**kwargs: typing.Any,
|
||||
) -> typing.AsyncGenerator['config.ConfigurationType', None]:
|
||||
) -> collections.abc.AsyncGenerator['config.ConfigurationType', None]:
|
||||
# Generate a listening server for testing tunnel
|
||||
# Prepare the end of the tunnel
|
||||
async with tools.AsyncTCPServer(
|
||||
@ -337,7 +340,7 @@ async def create_fake_broker_server(
|
||||
typing.Mapping[str, typing.Any],
|
||||
]
|
||||
],
|
||||
) -> typing.AsyncGenerator[asyncio.Queue[bytes], None]:
|
||||
) -> collections.abc.AsyncGenerator[asyncio.Queue[bytes], None]:
|
||||
# crate a fake broker server
|
||||
# Ignores request, and sends response
|
||||
# if is a callable, it will be called to get the response and encode it as json
|
||||
@ -388,7 +391,7 @@ async def open_tunnel_client(
|
||||
cfg: 'config.ConfigurationType',
|
||||
use_tunnel_handshake: bool = False,
|
||||
local_port: typing.Optional[int] = None,
|
||||
) -> typing.AsyncGenerator[
|
||||
) -> collections.abc.AsyncGenerator[
|
||||
typing.Tuple[asyncio.StreamReader, asyncio.StreamWriter], None
|
||||
]:
|
||||
"""opens an ssl socket to the tunnel server"""
|
||||
@ -414,6 +417,12 @@ async def open_tunnel_client(
|
||||
sock.setblocking(False)
|
||||
await loop.sock_connect(sock, (cfg.listen_address, cfg.listen_port))
|
||||
await loop.sock_sendall(sock, consts.HANDSHAKE_V1)
|
||||
# Note, we need an small delay, because the "middle connection", in case of tunnel proc,
|
||||
# that will simulate the tunnel handshake processor is running over a bufferedReader
|
||||
# (reads chunks of 4096 bytes). If we don't wait, the handshake will be readed
|
||||
# and part or all of ssl handshake also.
|
||||
# With uvloop this seems to be not needed, but with asyncio it is.
|
||||
await asyncio.sleep(0.05)
|
||||
# upgrade to ssl
|
||||
reader, writer = await asyncio.open_connection(
|
||||
sock=sock, ssl=context, server_hostname=cfg.listen_address
|
||||
@ -433,7 +442,7 @@ async def tunnel_app_runner(
|
||||
wait_for_port: bool = False,
|
||||
args: typing.Optional[typing.List[str]] = None,
|
||||
**kwargs: typing.Union[str, int, bool],
|
||||
) -> typing.AsyncGenerator['Process', None]:
|
||||
) -> collections.abc.AsyncGenerator['Process', None]:
|
||||
# Ensure we are on src directory
|
||||
if os.path.basename(os.getcwd()) != 'src':
|
||||
os.chdir('src')
|
||||
|
Loading…
x
Reference in New Issue
Block a user