mirror of
https://github.com/dkmstr/openuds.git
synced 2024-12-23 17:34:17 +03:00
some fix for concurrent tests
This commit is contained in:
parent
97c72ee7ac
commit
19709bfe3b
@ -36,7 +36,7 @@ CONFIGFILE: typing.Final[str] = '/etc/udstunnel.conf' if not DEBUG else 'udstunn
|
||||
LOGFORMAT: typing.Final[str] = (
|
||||
'%(levelname)s %(asctime)s %(message)s'
|
||||
if not DEBUG
|
||||
else '%(levelname)s %(asctime)s %(message)s'
|
||||
else '%(levelname)s %(asctime)s %(name)s:%(funcName)s %(lineno)d %(message)s'
|
||||
)
|
||||
|
||||
# MAX Length of read buffer for proxyed requests
|
||||
|
@ -64,11 +64,11 @@ if typing.TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
running: threading.Event = threading.Event()
|
||||
do_stop: threading.Event = threading.Event()
|
||||
|
||||
|
||||
def stop_signal(signum: int, frame: typing.Any) -> None:
|
||||
running.clear()
|
||||
do_stop.set()
|
||||
logger.debug('SIGNAL %s, frame: %s', signum, frame)
|
||||
|
||||
|
||||
@ -169,7 +169,7 @@ async def tunnel_proc_async(
|
||||
tasks.append(asyncio.create_task(run_server()))
|
||||
|
||||
try:
|
||||
while tasks and running.is_set():
|
||||
while tasks and not do_stop.is_set():
|
||||
to_wait = tasks[:] # Get a copy of the list
|
||||
# Wait for "to_wait" tasks to finish, stop every 2 seconds to check if we need to stop
|
||||
done, _ = await asyncio.wait(to_wait, return_when=asyncio.FIRST_COMPLETED, timeout=2)
|
||||
@ -179,14 +179,16 @@ async def tunnel_proc_async(
|
||||
if task.exception():
|
||||
logger.exception('TUNNEL ERROR')
|
||||
except asyncio.CancelledError:
|
||||
running.clear() # ensure we stop
|
||||
logger.info('Task cancelled')
|
||||
do_stop.set() # ensure we stop
|
||||
|
||||
logger.debug('Out of loop, stopping tasks: %s, running: %s', tasks, do_stop.is_set())
|
||||
# If any task is still running, cancel it
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
|
||||
# Wait for all tasks to finish
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)
|
||||
|
||||
logger.info('PROCESS %s stopped', os.getpid())
|
||||
|
||||
@ -199,11 +201,11 @@ def process_connection(
|
||||
data = client.recv(len(consts.HANDSHAKE_V1))
|
||||
|
||||
if data != consts.HANDSHAKE_V1:
|
||||
raise Exception() # Invalid handshake
|
||||
raise Exception('Invalid data: {} ({})'.format( addr, data.hex())) # Invalid handshake
|
||||
conn.send((client, addr))
|
||||
del client # Ensure socket is controlled on child process
|
||||
except Exception:
|
||||
logger.error('HANDSHAKE invalid from %s (%s)', addr, data.hex())
|
||||
except Exception as e:
|
||||
logger.error('HANDSHAKE invalid (%s)', e)
|
||||
# Close Source and continue
|
||||
client.close()
|
||||
|
||||
@ -268,10 +270,9 @@ def tunnel_main(args: 'argparse.Namespace') -> None:
|
||||
|
||||
prcs = processes.Processes(tunnel_proc_async, cfg, stats_collector.ns)
|
||||
|
||||
running.set() # Signal we are running
|
||||
with ThreadPoolExecutor(max_workers=256) as executor:
|
||||
with ThreadPoolExecutor(max_workers=16) as executor:
|
||||
try:
|
||||
while running.is_set():
|
||||
while not do_stop.is_set():
|
||||
try:
|
||||
client, addr = sock.accept()
|
||||
logger.info('CONNECTION from %s', addr)
|
||||
|
@ -54,9 +54,16 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
|
||||
self.assertEqual(stderr, b'')
|
||||
self.assertIn(b'usage: udstunnel', stdout)
|
||||
|
||||
async def client_task(self, host: str, port: int) -> None:
|
||||
async def client_task(self, host: str, tunnel_port: int, remote_port: int) -> None:
|
||||
received: bytes = b''
|
||||
callback_invoked: asyncio.Event = asyncio.Event()
|
||||
# Data sent will be received by server
|
||||
# One single write will ensure all data is on same packet
|
||||
test_str = (
|
||||
b'Some Random Data'
|
||||
+ bytes(random.randint(0, 255) for _ in range(1024)) * 4
|
||||
+ b'STREAM_END'
|
||||
)
|
||||
|
||||
def callback(data: bytes) -> None:
|
||||
nonlocal received
|
||||
@ -66,18 +73,20 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
|
||||
callback_invoked.set()
|
||||
|
||||
async with tools.AsyncTCPServer(
|
||||
host=host, port=5445, callback=callback
|
||||
host=host, port=remote_port, callback=callback, name='client_task'
|
||||
) as server:
|
||||
# Create a random ticket with valid format
|
||||
ticket = tuntools.get_correct_ticket()
|
||||
ticket = tuntools.get_correct_ticket(prefix=f'bX0bwmb{remote_port}bX0bwmb')
|
||||
# Open and send handshake
|
||||
# Fake config, only needed data for open_tunnel_client
|
||||
cfg = mock.MagicMock()
|
||||
cfg.ipv6 = ':' in host
|
||||
cfg.listen_address = host
|
||||
cfg.listen_port = port
|
||||
cfg.listen_port = tunnel_port
|
||||
|
||||
async with tuntools.open_tunnel_client(cfg, use_tunnel_handshake=True) as (
|
||||
async with tuntools.open_tunnel_client(
|
||||
cfg, local_port=remote_port + 10000, use_tunnel_handshake=True
|
||||
) as (
|
||||
creader,
|
||||
cwriter,
|
||||
):
|
||||
@ -89,37 +98,76 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
|
||||
await cwriter.drain()
|
||||
# Read response, should be ok
|
||||
data = await creader.read(1024)
|
||||
logger.debug('Received response: %r', data)
|
||||
self.assertEqual(
|
||||
data,
|
||||
consts.RESPONSE_OK,
|
||||
f'Server host: {host}:{port} - Ticket: {ticket!r} - Response: {data!r}',
|
||||
f'Server host: {host}:{tunnel_port} - Ticket: {ticket!r} - Response: {data!r}',
|
||||
)
|
||||
# Clean received data
|
||||
received = b''
|
||||
# And reset event
|
||||
callback_invoked.clear()
|
||||
|
||||
cwriter.write(test_str)
|
||||
await cwriter.drain()
|
||||
# Close connection
|
||||
cwriter.close()
|
||||
|
||||
# Wait for callback to be invoked
|
||||
await callback_invoked.wait()
|
||||
self.assertEqual(received, test_str)
|
||||
|
||||
async def test_run_app_serve(self) -> None:
|
||||
return
|
||||
port = random.randint(10000, 20000)
|
||||
concurrent_tasks = 256
|
||||
fake_broker_port = 20000
|
||||
tunnel_server_port = fake_broker_port + 1
|
||||
remote_port = fake_broker_port + 2
|
||||
# Extracts the port from an string that has bX0bwmbPORTbX0bwmb in it
|
||||
def extract_port(data: bytes) -> int:
|
||||
if b'bX0bwmb' not in data:
|
||||
return 12345 # No port, wil not be used because is an "stop" request
|
||||
return int(data.split(b'bX0bwmb')[1])
|
||||
|
||||
for host in ('127.0.0.1', '::1'):
|
||||
if ':' in host:
|
||||
url = f'http://[{host}]:{port}/uds/rest'
|
||||
url = f'http://[{host}]:{fake_broker_port}/uds/rest'
|
||||
else:
|
||||
url = f'http://{host}:{port}/uds/rest'
|
||||
url = f'http://{host}:{fake_broker_port}/uds/rest'
|
||||
# Create fake uds broker
|
||||
async with tuntools.create_fake_broker_server(
|
||||
host, port, response=conf.UDS_GET_TICKET_RESPONSE(host, port)
|
||||
) as broker:
|
||||
host,
|
||||
fake_broker_port,
|
||||
response=lambda data: conf.UDS_GET_TICKET_RESPONSE(
|
||||
host, extract_port(data)
|
||||
),
|
||||
) as req_queue:
|
||||
if req_queue is None:
|
||||
raise AssertionError('req_queue is None')
|
||||
|
||||
async with tuntools.tunnel_app_runner(
|
||||
host, 7770, uds_server=url,
|
||||
host,
|
||||
tunnel_server_port,
|
||||
wait_for_port=True,
|
||||
# Tunnel config
|
||||
uds_server=url,
|
||||
logfile='/tmp/tunnel_test.log',
|
||||
loglevel='DEBUG',
|
||||
workers=4,
|
||||
) as process:
|
||||
# Create a "bunch" of clients
|
||||
tasks = [
|
||||
asyncio.create_task(self.client_task(host, 7777))
|
||||
for _ in range(1)
|
||||
asyncio.create_task(
|
||||
self.client_task(host, tunnel_server_port, remote_port + i)
|
||||
)
|
||||
for i in range(concurrent_tasks)
|
||||
]
|
||||
|
||||
# Wait for all tasks to finish
|
||||
await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)
|
||||
|
||||
# If any exception was raised, raise it
|
||||
for task in tasks:
|
||||
task.result()
|
||||
# Queue should have all requests (concurrent_tasks*2, one for open and one for close)
|
||||
self.assertEqual(req_queue.qsize(), concurrent_tasks * 2)
|
||||
|
@ -34,10 +34,12 @@ import ssl
|
||||
import typing
|
||||
import socket
|
||||
import aiohttp
|
||||
import logging
|
||||
from unittest import mock
|
||||
|
||||
from . import certs
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class AsyncMock(mock.MagicMock):
|
||||
async def __call__(self, *args, **kwargs):
|
||||
@ -115,6 +117,7 @@ class AsyncTCPServer:
|
||||
_processor: typing.Optional[
|
||||
typing.Callable[[asyncio.StreamReader, asyncio.StreamWriter], typing.Awaitable[None]]
|
||||
]
|
||||
_name: str # For debug purposes
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -128,6 +131,7 @@ class AsyncTCPServer:
|
||||
processor: typing.Optional[
|
||||
typing.Callable[[asyncio.StreamReader, asyncio.StreamWriter], typing.Awaitable[None]]
|
||||
] = None,
|
||||
name: str = 'AsyncTCPServer',
|
||||
) -> None:
|
||||
self.host = host
|
||||
self.port = port
|
||||
@ -135,15 +139,17 @@ class AsyncTCPServer:
|
||||
self._response = response
|
||||
self._callback = callback
|
||||
self._processor = processor
|
||||
self._name = name
|
||||
|
||||
self.data = b''
|
||||
|
||||
async def _handle(self, reader, writer) -> None:
|
||||
logger.debug('Handling connection for %s', self._name)
|
||||
if self._processor is not None:
|
||||
await self._processor(reader, writer)
|
||||
return
|
||||
while True:
|
||||
data = await reader.read(2048)
|
||||
data = await reader.read(4096)
|
||||
if not data:
|
||||
break
|
||||
|
||||
|
@ -145,8 +145,12 @@ async def create_tunnel_proc(
|
||||
resp = conf.UDS_GET_TICKET_RESPONSE(remote_host, remote_port)
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def provider() -> typing.AsyncGenerator[typing.Optional[asyncio.Queue[bytes]], None]:
|
||||
async with create_fake_broker_server(listen_host, port, response=resp) as queue:
|
||||
async def provider() -> typing.AsyncGenerator[
|
||||
typing.Optional[asyncio.Queue[bytes]], None
|
||||
]:
|
||||
async with create_fake_broker_server(
|
||||
listen_host, port, response=resp
|
||||
) as queue:
|
||||
try:
|
||||
yield queue
|
||||
finally:
|
||||
@ -155,7 +159,9 @@ async def create_tunnel_proc(
|
||||
else:
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def provider() -> typing.AsyncGenerator[typing.Optional[asyncio.Queue[bytes]], None]:
|
||||
async def provider() -> typing.AsyncGenerator[
|
||||
typing.Optional[asyncio.Queue[bytes]], None
|
||||
]:
|
||||
with mock.patch(
|
||||
'uds_tunnel.tunnel.TunnelProtocol._read_from_uds',
|
||||
new_callable=tools.AsyncMock,
|
||||
@ -187,8 +193,8 @@ async def create_tunnel_proc(
|
||||
|
||||
udstunnel.setup_log(cfg)
|
||||
|
||||
# Set running flag
|
||||
udstunnel.running.set()
|
||||
# Clear the stop flag
|
||||
udstunnel.do_stop.clear()
|
||||
|
||||
# Create the tunnel task
|
||||
task = asyncio.create_task(
|
||||
@ -281,7 +287,7 @@ async def create_test_tunnel(
|
||||
# Generate a listening server for testing tunnel
|
||||
# Prepare the end of the tunnel
|
||||
async with tools.AsyncTCPServer(
|
||||
port=remote_port or 54876, callback=callback
|
||||
port=remote_port or 54876, callback=callback, name='create_test_tunnel'
|
||||
) as server:
|
||||
# Create a tunnel to localhost 13579
|
||||
# SSl cert for tunnel server
|
||||
@ -307,30 +313,44 @@ async def create_test_tunnel(
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def create_fake_broker_server(
|
||||
host: str, port: int, *, response: typing.Mapping[str, typing.Any]
|
||||
host: str,
|
||||
port: int,
|
||||
*,
|
||||
response: typing.Union[
|
||||
typing.Callable[[bytes], typing.Mapping[str, typing.Any]],
|
||||
typing.Mapping[str, typing.Any],
|
||||
],
|
||||
) -> typing.AsyncGenerator[asyncio.Queue[bytes], None]:
|
||||
# crate a fake broker server
|
||||
# Ignores request, and sends response
|
||||
# if is a callable, it will be called to get the response and encode it as json
|
||||
|
||||
resp: bytes = (
|
||||
b'HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n\r\n'
|
||||
+ json.dumps(response).encode()
|
||||
)
|
||||
|
||||
# Future to content the server request
|
||||
# to content the server request
|
||||
data: bytes = b''
|
||||
requests: asyncio.Queue[bytes] = asyncio.Queue()
|
||||
|
||||
async def processor(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
|
||||
nonlocal data
|
||||
nonlocal response
|
||||
|
||||
while readed := await reader.read(1024):
|
||||
data += readed
|
||||
# If data ends in \r\n\r\n, we have the full request
|
||||
if data.endswith(b'\r\n\r\n'):
|
||||
requests.put_nowait(data)
|
||||
data = b'' # reset data for next
|
||||
break
|
||||
|
||||
if callable(response):
|
||||
rr = response(data)
|
||||
else:
|
||||
rr = response
|
||||
|
||||
resp: bytes = (
|
||||
b'HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n\r\n'
|
||||
+ json.dumps(rr).encode()
|
||||
)
|
||||
|
||||
data = b'' # reset data for next
|
||||
# send response
|
||||
writer.write(resp)
|
||||
await writer.drain()
|
||||
@ -338,7 +358,7 @@ async def create_fake_broker_server(
|
||||
writer.close()
|
||||
|
||||
async with tools.AsyncTCPServer(
|
||||
host=host, port=port, response=resp, processor=processor
|
||||
host=host, port=port, processor=processor, name='create_fake_broker_server'
|
||||
) as server:
|
||||
try:
|
||||
yield requests
|
||||
@ -350,6 +370,7 @@ async def create_fake_broker_server(
|
||||
async def open_tunnel_client(
|
||||
cfg: 'config.ConfigurationType',
|
||||
use_tunnel_handshake: bool = False,
|
||||
local_port: typing.Optional[int] = None,
|
||||
) -> typing.AsyncGenerator[
|
||||
typing.Tuple[asyncio.StreamReader, asyncio.StreamWriter], None
|
||||
]:
|
||||
@ -369,6 +390,9 @@ async def open_tunnel_client(
|
||||
else:
|
||||
# Open the socket, send handshake and then upgrade to ssl, non blocking
|
||||
sock = socket.socket(family, socket.SOCK_STREAM)
|
||||
if local_port:
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
sock.bind(('', local_port))
|
||||
# Set socket to non blocking
|
||||
sock.setblocking(False)
|
||||
await loop.sock_connect(sock, (cfg.listen_address, cfg.listen_port))
|
||||
@ -425,7 +449,15 @@ async def tunnel_app_runner(
|
||||
await process.wait()
|
||||
|
||||
|
||||
def get_correct_ticket(length: int = consts.TICKET_LENGTH) -> bytes:
|
||||
return ''.join(
|
||||
random.choice(string.ascii_letters + string.digits) for _ in range(length)
|
||||
def get_correct_ticket(
|
||||
length: int = consts.TICKET_LENGTH, *, prefix: typing.Optional[str] = None
|
||||
) -> bytes:
|
||||
"""Returns a ticket with the correct length"""
|
||||
prefix = prefix or ''
|
||||
return (
|
||||
''.join(
|
||||
random.choice(string.ascii_letters + string.digits)
|
||||
for _ in range(length - len(prefix))
|
||||
).encode()
|
||||
+ prefix.encode()
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user