1
0
mirror of https://github.com/dkmstr/openuds.git synced 2024-12-23 17:34:17 +03:00

some fix for concurrent tests

This commit is contained in:
Adolfo Gómez García 2022-12-21 21:06:42 +01:00
parent 97c72ee7ac
commit 19709bfe3b
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
5 changed files with 135 additions and 48 deletions

View File

@ -36,7 +36,7 @@ CONFIGFILE: typing.Final[str] = '/etc/udstunnel.conf' if not DEBUG else 'udstunn
LOGFORMAT: typing.Final[str] = (
'%(levelname)s %(asctime)s %(message)s'
if not DEBUG
else '%(levelname)s %(asctime)s %(message)s'
else '%(levelname)s %(asctime)s %(name)s:%(funcName)s %(lineno)d %(message)s'
)
# MAX Length of read buffer for proxyed requests

View File

@ -64,11 +64,11 @@ if typing.TYPE_CHECKING:
logger = logging.getLogger(__name__)
running: threading.Event = threading.Event()
do_stop: threading.Event = threading.Event()
def stop_signal(signum: int, frame: typing.Any) -> None:
running.clear()
do_stop.set()
logger.debug('SIGNAL %s, frame: %s', signum, frame)
@ -169,7 +169,7 @@ async def tunnel_proc_async(
tasks.append(asyncio.create_task(run_server()))
try:
while tasks and running.is_set():
while tasks and not do_stop.is_set():
to_wait = tasks[:] # Get a copy of the list
# Wait for "to_wait" tasks to finish, stop every 2 seconds to check if we need to stop
done, _ = await asyncio.wait(to_wait, return_when=asyncio.FIRST_COMPLETED, timeout=2)
@ -179,14 +179,16 @@ async def tunnel_proc_async(
if task.exception():
logger.exception('TUNNEL ERROR')
except asyncio.CancelledError:
running.clear() # ensure we stop
logger.info('Task cancelled')
do_stop.set() # ensure we stop
logger.debug('Out of loop, stopping tasks: %s, running: %s', tasks, do_stop.is_set())
# If any task is still running, cancel it
for task in tasks:
task.cancel()
# Wait for all tasks to finish
await asyncio.gather(*tasks, return_exceptions=True)
await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)
logger.info('PROCESS %s stopped', os.getpid())
@ -199,11 +201,11 @@ def process_connection(
data = client.recv(len(consts.HANDSHAKE_V1))
if data != consts.HANDSHAKE_V1:
raise Exception() # Invalid handshake
raise Exception('Invalid data: {} ({})'.format( addr, data.hex())) # Invalid handshake
conn.send((client, addr))
del client # Ensure socket is controlled on child process
except Exception:
logger.error('HANDSHAKE invalid from %s (%s)', addr, data.hex())
except Exception as e:
logger.error('HANDSHAKE invalid (%s)', e)
# Close Source and continue
client.close()
@ -268,10 +270,9 @@ def tunnel_main(args: 'argparse.Namespace') -> None:
prcs = processes.Processes(tunnel_proc_async, cfg, stats_collector.ns)
running.set() # Signal we are running
with ThreadPoolExecutor(max_workers=256) as executor:
with ThreadPoolExecutor(max_workers=16) as executor:
try:
while running.is_set():
while not do_stop.is_set():
try:
client, addr = sock.accept()
logger.info('CONNECTION from %s', addr)

View File

@ -54,9 +54,16 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
self.assertEqual(stderr, b'')
self.assertIn(b'usage: udstunnel', stdout)
async def client_task(self, host: str, port: int) -> None:
async def client_task(self, host: str, tunnel_port: int, remote_port: int) -> None:
received: bytes = b''
callback_invoked: asyncio.Event = asyncio.Event()
# Data sent will be received by server
# One single write will ensure all data is on same packet
test_str = (
b'Some Random Data'
+ bytes(random.randint(0, 255) for _ in range(1024)) * 4
+ b'STREAM_END'
)
def callback(data: bytes) -> None:
nonlocal received
@ -66,18 +73,20 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
callback_invoked.set()
async with tools.AsyncTCPServer(
host=host, port=5445, callback=callback
host=host, port=remote_port, callback=callback, name='client_task'
) as server:
# Create a random ticket with valid format
ticket = tuntools.get_correct_ticket()
ticket = tuntools.get_correct_ticket(prefix=f'bX0bwmb{remote_port}bX0bwmb')
# Open and send handshake
# Fake config, only needed data for open_tunnel_client
cfg = mock.MagicMock()
cfg.ipv6 = ':' in host
cfg.listen_address = host
cfg.listen_port = port
cfg.listen_port = tunnel_port
async with tuntools.open_tunnel_client(cfg, use_tunnel_handshake=True) as (
async with tuntools.open_tunnel_client(
cfg, local_port=remote_port + 10000, use_tunnel_handshake=True
) as (
creader,
cwriter,
):
@ -89,37 +98,76 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
await cwriter.drain()
# Read response, should be ok
data = await creader.read(1024)
logger.debug('Received response: %r', data)
self.assertEqual(
data,
consts.RESPONSE_OK,
f'Server host: {host}:{port} - Ticket: {ticket!r} - Response: {data!r}',
f'Server host: {host}:{tunnel_port} - Ticket: {ticket!r} - Response: {data!r}',
)
# Clean received data
received = b''
# And reset event
callback_invoked.clear()
cwriter.write(test_str)
await cwriter.drain()
# Close connection
cwriter.close()
# Wait for callback to be invoked
await callback_invoked.wait()
self.assertEqual(received, test_str)
async def test_run_app_serve(self) -> None:
return
port = random.randint(10000, 20000)
concurrent_tasks = 256
fake_broker_port = 20000
tunnel_server_port = fake_broker_port + 1
remote_port = fake_broker_port + 2
# Extracts the port from an string that has bX0bwmbPORTbX0bwmb in it
def extract_port(data: bytes) -> int:
if b'bX0bwmb' not in data:
return 12345 # No port, wil not be used because is an "stop" request
return int(data.split(b'bX0bwmb')[1])
for host in ('127.0.0.1', '::1'):
if ':' in host:
url = f'http://[{host}]:{port}/uds/rest'
url = f'http://[{host}]:{fake_broker_port}/uds/rest'
else:
url = f'http://{host}:{port}/uds/rest'
url = f'http://{host}:{fake_broker_port}/uds/rest'
# Create fake uds broker
async with tuntools.create_fake_broker_server(
host, port, response=conf.UDS_GET_TICKET_RESPONSE(host, port)
) as broker:
host,
fake_broker_port,
response=lambda data: conf.UDS_GET_TICKET_RESPONSE(
host, extract_port(data)
),
) as req_queue:
if req_queue is None:
raise AssertionError('req_queue is None')
async with tuntools.tunnel_app_runner(
host, 7770, uds_server=url,
host,
tunnel_server_port,
wait_for_port=True,
# Tunnel config
uds_server=url,
logfile='/tmp/tunnel_test.log',
loglevel='DEBUG',
workers=4,
) as process:
# Create a "bunch" of clients
tasks = [
asyncio.create_task(self.client_task(host, 7777))
for _ in range(1)
asyncio.create_task(
self.client_task(host, tunnel_server_port, remote_port + i)
)
for i in range(concurrent_tasks)
]
# Wait for all tasks to finish
await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)
# If any exception was raised, raise it
for task in tasks:
task.result()
# Queue should have all requests (concurrent_tasks*2, one for open and one for close)
self.assertEqual(req_queue.qsize(), concurrent_tasks * 2)

View File

@ -34,10 +34,12 @@ import ssl
import typing
import socket
import aiohttp
import logging
from unittest import mock
from . import certs
logger = logging.getLogger(__name__)
class AsyncMock(mock.MagicMock):
async def __call__(self, *args, **kwargs):
@ -115,6 +117,7 @@ class AsyncTCPServer:
_processor: typing.Optional[
typing.Callable[[asyncio.StreamReader, asyncio.StreamWriter], typing.Awaitable[None]]
]
_name: str # For debug purposes
def __init__(
self,
@ -128,6 +131,7 @@ class AsyncTCPServer:
processor: typing.Optional[
typing.Callable[[asyncio.StreamReader, asyncio.StreamWriter], typing.Awaitable[None]]
] = None,
name: str = 'AsyncTCPServer',
) -> None:
self.host = host
self.port = port
@ -135,15 +139,17 @@ class AsyncTCPServer:
self._response = response
self._callback = callback
self._processor = processor
self._name = name
self.data = b''
async def _handle(self, reader, writer) -> None:
logger.debug('Handling connection for %s', self._name)
if self._processor is not None:
await self._processor(reader, writer)
return
while True:
data = await reader.read(2048)
data = await reader.read(4096)
if not data:
break

View File

@ -145,8 +145,12 @@ async def create_tunnel_proc(
resp = conf.UDS_GET_TICKET_RESPONSE(remote_host, remote_port)
@contextlib.asynccontextmanager
async def provider() -> typing.AsyncGenerator[typing.Optional[asyncio.Queue[bytes]], None]:
async with create_fake_broker_server(listen_host, port, response=resp) as queue:
async def provider() -> typing.AsyncGenerator[
typing.Optional[asyncio.Queue[bytes]], None
]:
async with create_fake_broker_server(
listen_host, port, response=resp
) as queue:
try:
yield queue
finally:
@ -155,7 +159,9 @@ async def create_tunnel_proc(
else:
@contextlib.asynccontextmanager
async def provider() -> typing.AsyncGenerator[typing.Optional[asyncio.Queue[bytes]], None]:
async def provider() -> typing.AsyncGenerator[
typing.Optional[asyncio.Queue[bytes]], None
]:
with mock.patch(
'uds_tunnel.tunnel.TunnelProtocol._read_from_uds',
new_callable=tools.AsyncMock,
@ -187,8 +193,8 @@ async def create_tunnel_proc(
udstunnel.setup_log(cfg)
# Set running flag
udstunnel.running.set()
# Clear the stop flag
udstunnel.do_stop.clear()
# Create the tunnel task
task = asyncio.create_task(
@ -281,7 +287,7 @@ async def create_test_tunnel(
# Generate a listening server for testing tunnel
# Prepare the end of the tunnel
async with tools.AsyncTCPServer(
port=remote_port or 54876, callback=callback
port=remote_port or 54876, callback=callback, name='create_test_tunnel'
) as server:
# Create a tunnel to localhost 13579
# SSl cert for tunnel server
@ -307,30 +313,44 @@ async def create_test_tunnel(
@contextlib.asynccontextmanager
async def create_fake_broker_server(
host: str, port: int, *, response: typing.Mapping[str, typing.Any]
host: str,
port: int,
*,
response: typing.Union[
typing.Callable[[bytes], typing.Mapping[str, typing.Any]],
typing.Mapping[str, typing.Any],
],
) -> typing.AsyncGenerator[asyncio.Queue[bytes], None]:
# crate a fake broker server
# Ignores request, and sends response
# if is a callable, it will be called to get the response and encode it as json
resp: bytes = (
b'HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n\r\n'
+ json.dumps(response).encode()
)
# Future to content the server request
# to content the server request
data: bytes = b''
requests: asyncio.Queue[bytes] = asyncio.Queue()
async def processor(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
nonlocal data
nonlocal response
while readed := await reader.read(1024):
data += readed
# If data ends in \r\n\r\n, we have the full request
if data.endswith(b'\r\n\r\n'):
requests.put_nowait(data)
data = b'' # reset data for next
break
if callable(response):
rr = response(data)
else:
rr = response
resp: bytes = (
b'HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n\r\n'
+ json.dumps(rr).encode()
)
data = b'' # reset data for next
# send response
writer.write(resp)
await writer.drain()
@ -338,7 +358,7 @@ async def create_fake_broker_server(
writer.close()
async with tools.AsyncTCPServer(
host=host, port=port, response=resp, processor=processor
host=host, port=port, processor=processor, name='create_fake_broker_server'
) as server:
try:
yield requests
@ -350,6 +370,7 @@ async def create_fake_broker_server(
async def open_tunnel_client(
cfg: 'config.ConfigurationType',
use_tunnel_handshake: bool = False,
local_port: typing.Optional[int] = None,
) -> typing.AsyncGenerator[
typing.Tuple[asyncio.StreamReader, asyncio.StreamWriter], None
]:
@ -369,6 +390,9 @@ async def open_tunnel_client(
else:
# Open the socket, send handshake and then upgrade to ssl, non blocking
sock = socket.socket(family, socket.SOCK_STREAM)
if local_port:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(('', local_port))
# Set socket to non blocking
sock.setblocking(False)
await loop.sock_connect(sock, (cfg.listen_address, cfg.listen_port))
@ -425,7 +449,15 @@ async def tunnel_app_runner(
await process.wait()
def get_correct_ticket(length: int = consts.TICKET_LENGTH) -> bytes:
return ''.join(
random.choice(string.ascii_letters + string.digits) for _ in range(length)
def get_correct_ticket(
length: int = consts.TICKET_LENGTH, *, prefix: typing.Optional[str] = None
) -> bytes:
"""Returns a ticket with the correct length"""
prefix = prefix or ''
return (
''.join(
random.choice(string.ascii_letters + string.digits)
for _ in range(length - len(prefix))
).encode()
+ prefix.encode()
)