1
0
mirror of https://github.com/dkmstr/openuds.git synced 2024-12-22 13:34:04 +03:00

Small fixes and more tests. Refactoring.

This commit is contained in:
Adolfo Gómez García 2022-12-20 18:31:04 +01:00
parent fdd8ac00c9
commit 472299f878
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
10 changed files with 171 additions and 58 deletions

View File

@ -0,0 +1,2 @@
htmlcov
.coverage

View File

@ -1,5 +1,8 @@
[run]
dynamic_context = test_function
omit =
test/*
/etc/*
[report]
skip_empty = True

View File

@ -1,5 +1,6 @@
[pytest]
addopts = "-s"
#addopts = "-s"
addopts = --cov --cov-report html --cov-config=coverage.ini -n 12
pythonpath = ./src
python_files = tests.py test_*.py *_tests.py
log_format = "%(asctime)s %(levelname)s %(message)s"

View File

@ -114,7 +114,7 @@ class TunnelProtocol(asyncio.Protocol):
async def open_other_side() -> None:
try:
result = await TunnelProtocol.getTicketFromUDS(
result = await TunnelProtocol.get_ticket_from_uds(
self.owner.cfg, ticket, self.source
)
except Exception as e:
@ -277,10 +277,10 @@ class TunnelProtocol(asyncio.Protocol):
logger.debug('Data received: %s', len(data))
self.runner(data) # send data to current runner (command or proxy)
def notifyEnd(self):
def notify_end(self):
if self.notify_ticket:
asyncio.get_event_loop().create_task(
TunnelProtocol.notifyEndToUds(
TunnelProtocol.notify_end_to_uds(
self.owner.cfg, self.notify_ticket, self.stats_manager
)
)
@ -295,7 +295,7 @@ class TunnelProtocol(asyncio.Protocol):
self.other_side.transport.close()
else:
self.stats_manager.close()
self.notifyEnd()
self.notify_end()
# helpers
@staticmethod
@ -324,12 +324,12 @@ class TunnelProtocol(asyncio.Protocol):
int(self.stats_manager.end - self.stats_manager.start),
)
# Notify end to uds
self.notifyEnd()
self.notify_end()
else:
logger.info('TERMINATED %s', self.pretty_source())
@staticmethod
async def _readFromUDS(
async def _read_from_uds(
cfg: config.ConfigurationType,
ticket: bytes,
msg: str,
@ -358,7 +358,7 @@ class TunnelProtocol(asyncio.Protocol):
raise Exception(f'TICKET COMMS ERROR: {ticket.decode()} {msg} {e!s}')
@staticmethod
async def getTicketFromUDS(
async def get_ticket_from_uds(
cfg: config.ConfigurationType, ticket: bytes, address: typing.Tuple[str, int]
) -> typing.MutableMapping[str, typing.Any]:
# Sanity checks
@ -374,13 +374,13 @@ class TunnelProtocol(asyncio.Protocol):
continue # Correctus
raise ValueError(f'TICKET INVALID (char {i} at pos {n})')
return await TunnelProtocol._readFromUDS(cfg, ticket, address[0])
return await TunnelProtocol._read_from_uds(cfg, ticket, address[0])
@staticmethod
async def notifyEndToUds(
async def notify_end_to_uds(
cfg: config.ConfigurationType, ticket: bytes, counter: stats.Stats
) -> None:
await TunnelProtocol._readFromUDS(
await TunnelProtocol._read_from_uds(
cfg,
ticket,
'stop',

View File

@ -30,7 +30,8 @@ Author: Adolfo Gómez, dkmaster at dkmon dot com
'''
import typing
import asyncio
import io
import random
import string
import logging
from unittest import IsolatedAsyncioTestCase, mock
@ -38,8 +39,12 @@ from uds_tunnel import consts
from .utils import tuntools, tools
if typing.TYPE_CHECKING:
from uds_tunnel import config
logger = logging.getLogger(__name__)
class TestUDSTunnelApp(IsolatedAsyncioTestCase):
async def test_run_app_help(self) -> None:
# Executes the app with --help
@ -49,10 +54,49 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
self.assertEqual(process.returncode, 0, f'{stdout!r} {stderr!r}')
self.assertEqual(stderr, b'')
self.assertIn(b'usage: udstunnel', stdout)
async def client_task(
self, cfg: 'config.ConfigurationType', host: str, port: int
) -> None:
received: bytes = b''
callback_invoked: asyncio.Event = asyncio.Event()
def callback(data: bytes) -> None:
nonlocal received
received += data
# if data contains EOS marcker ('STREAM_END'), we are done
if b'STREAM_END' in data:
callback_invoked.set()
async with tools.AsyncTCPServer(
host=host, port=5445, callback=callback
) as server:
# Create a random ticket with valid format
ticket = ''.join(
random.choice(string.ascii_letters + string.digits)
for _ in range(consts.TICKET_LENGTH)
).encode()
# Open and send handshake
async with tuntools.open_tunnel_client(cfg, use_tunnel_handshake=True) as (
creader,
cwriter,
):
# Now open command with ticket
cwriter.write(consts.COMMAND_OPEN)
# fake ticket, consts.TICKET_LENGTH bytes long, letters and numbers. Use a random ticket,
cwriter.write(ticket)
await cwriter.drain()
# Read response, should be ok
data = await creader.read(1024)
self.assertEqual(
data,
consts.RESPONSE_OK,
f'Server host: {host}:{port} - Ticket: {ticket!r} - Response: {data!r}',
)
async def test_run_app_serve(self) -> None:
for host in ('127.0.0.1', '::1'):
async with tuntools.tunnel_app_runner(host, 7777) as process:
print(process)
# Create a "bunch" of servers and clients
pass

View File

@ -66,7 +66,7 @@ class TestTunnel(IsolatedAsyncioTestCase):
# Set timeout to 1 seconds
bad_cmd = bytes(random.randint(0, 255) for _ in range(i)) # Some garbage
logger.info(f'Testing invalid command with {bad_cmd!r}')
async with tuntools.create_test_tunnel(callback=lambda x: None) as cfg:
async with tuntools.create_test_tunnel(callback=lambda x: None, port=7770, remote_port=54555) as cfg:
logger_mock = mock.MagicMock()
with mock.patch('uds_tunnel.tunnel.logger', logger_mock):
# Open connection to tunnel

View File

@ -65,7 +65,7 @@ class TestTunnelHelpers(IsolatedAsyncioTestCase):
# Test some invalid tickets
# Valid ticket are consts.TICKET_LENGTH bytes long, and must be A-Z, a-z, 0-9
with mock.patch(
'uds_tunnel.tunnel.TunnelProtocol._readFromUDS',
'uds_tunnel.tunnel.TunnelProtocol._read_from_uds',
new_callable=tools.AsyncMock,
) as m:
m.side_effect = uds_response
@ -77,7 +77,7 @@ class TestTunnelHelpers(IsolatedAsyncioTestCase):
)
with self.assertRaises(ValueError):
await tunnel.TunnelProtocol.getTicketFromUDS(
await tunnel.TunnelProtocol.get_ticket_from_uds(
cfg, ticket.encode(), conf.CALLER_HOST
)
@ -85,7 +85,7 @@ class TestTunnelHelpers(IsolatedAsyncioTestCase):
for i in range(0, 100):
# Now some requests with valid tickets
# Ensure no exception is raised
ret_value = await tunnel.TunnelProtocol.getTicketFromUDS(
ret_value = await tunnel.TunnelProtocol.get_ticket_from_uds(
cfg, ticket.encode(), conf.CALLER_HOST
)
# Ensure data returned is correct {host, port, notify} from mock
@ -108,7 +108,7 @@ class TestTunnelHelpers(IsolatedAsyncioTestCase):
async def test_notify_end_to_uds_broker(self) -> None:
_, cfg = fixtures.get_config()
with mock.patch(
'uds_tunnel.tunnel.TunnelProtocol._readFromUDS',
'uds_tunnel.tunnel.TunnelProtocol._read_from_uds',
new_callable=tools.AsyncMock,
) as m:
m.side_effect = uds_response
@ -118,7 +118,7 @@ class TestTunnelHelpers(IsolatedAsyncioTestCase):
ticket = conf.NOTIFY_TICKET.encode()
for i in range(0, 100):
await tunnel.TunnelProtocol.notifyEndToUds(cfg, ticket, counter)
await tunnel.TunnelProtocol.notify_end_to_uds(cfg, ticket, counter)
self.assertEqual(m.call_args[0][0], cfg)
self.assertEqual(
@ -159,9 +159,9 @@ class TestTunnelHelpers(IsolatedAsyncioTestCase):
await tools.get(fake_uds_server),
'{"result":"ok"}',
)
# Now, tests _readFromUDS
# Now, tests _read_from_uds
for i in range(100):
ret = await tunnel.TunnelProtocol._readFromUDS(
ret = await tunnel.TunnelProtocol._read_from_uds(
cfg, conf.NOTIFY_TICKET.encode(), 'test', {'param': 'value'}
)
self.assertEqual(ret, {'result': 'ok'})

View File

@ -55,9 +55,9 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
# Remote is not really important in this tests, will fail before using it
async with tuntools.create_tunnel_proc(
host,
7777,
7890, # A port not used by any other test
'127.0.0.1',
12345,
13579, # A port not used by any other test
) as cfg:
for i in range(0, 8192, 128):
# Set timeout to 1 seconds
@ -85,9 +85,9 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
# Remote is not really important in this tests, will return ok before using it (this is a TEST command, not OPEN)
async with tuntools.create_tunnel_proc(
host,
7777,
7891,
'127.0.0.1',
12345,
13581,
) as cfg:
for i in range(10): # Several times
# On full, we need the handshake to be done, before connecting
@ -117,10 +117,7 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
0, consts.TICKET_LENGTH - 1, 4
): # All will fail. Any longer will be processed, and mock will return correct don't matter the ticket
# Ticket must contain only letters and numbers
ticket = ''.join(
random.choice(string.ascii_letters + string.digits)
for _ in range(i)
).encode()
ticket = tuntools.get_correct_ticket(i)
# On full, we need the handshake to be done, before connecting
# Our "test" server will simple "eat" the handshake, but we need to do it
async with tuntools.open_tunnel_client(
@ -161,10 +158,7 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
) as cfg:
for i in range(1):
# Create a random ticket with valid format
ticket = ''.join(
random.choice(string.ascii_letters + string.digits)
for _ in range(consts.TICKET_LENGTH)
).encode()
ticket = tuntools.get_correct_ticket()
# On full, we need the handshake to be done, before connecting
# Our "test" server will simple "eat" the handshake, but we need to do it
async with tuntools.open_tunnel_client(
@ -175,21 +169,53 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
cwriter.write(ticket)
await cwriter.drain()
# Read response
# Read response, should be ok
data = await creader.read(1024)
self.assertEqual(data, consts.RESPONSE_OK, f'Tunnel host: {tunnel_host}, server host: {host}')
self.assertEqual(
data,
consts.RESPONSE_OK,
f'Tunnel host: {tunnel_host}, server host: {host}',
)
# Data sent will be received by server
# One single write will ensure all data is on same packet
test_str = b'Some Random Data' + bytes(random.randint(0, 255) for _ in range(8192)) + b'STREAM_END'
test_str = (
b'Some Random Data'
+ bytes(random.randint(0, 255) for _ in range(8192))
+ b'STREAM_END'
)
# Clean received data
received = b''
# And reset event
callback_invoked.clear()
cwriter.write(test_str)
await cwriter.drain()
# Wait for callback to be invoked
await callback_invoked.wait()
self.assertEqual(received, test_str)
async def test_tunnel_no_remote(self) -> None:
for host in ('127.0.0.1', '::1'):
for tunnel_host in ('127.0.0.1', '::1'):
async with tuntools.create_tunnel_proc(
tunnel_host,
7888,
host,
17222, # Any non used port will do the trick
) as cfg:
ticket = tuntools.get_correct_ticket()
# On full, we need the handshake to be done, before connecting
# Our "test" server will simple "eat" the handshake, but we need to do it
async with tuntools.open_tunnel_client(
cfg, use_tunnel_handshake=True
) as (creader, cwriter):
cwriter.write(consts.COMMAND_OPEN)
# fake ticket, consts.TICKET_LENGTH bytes long, letters and numbers. Use a random ticket,
cwriter.write(ticket)
await cwriter.drain()
# Read response
data = await creader.read(1024)
self.assertEqual(data, b'', f'Tunnel host: {tunnel_host}, server host: {host}')

View File

@ -38,7 +38,6 @@ from unittest import mock
from . import certs
class AsyncMock(mock.MagicMock):
async def __call__(self, *args, **kwargs):
return super().__call__(*args, **kwargs)
@ -111,7 +110,8 @@ class AsyncTCPServer:
port: int
_server: typing.Optional[asyncio.AbstractServer]
_response: typing.Optional[bytes]
_callback: typing.Optional[typing.Callable[[bytes], None]]
_callback: typing.Optional[typing.Callable[[bytes], typing.Optional[bytes]]]
_processor: typing.Optional[typing.Callable[[asyncio.StreamReader, asyncio.StreamWriter], None]]
def __init__(
self,
@ -119,26 +119,30 @@ class AsyncTCPServer:
*,
response: typing.Optional[bytes] = None,
host: str = '127.0.0.1', # ip
callback: typing.Optional[typing.Callable[[bytes], None]] = None,
callback: typing.Optional[typing.Callable[[bytes], typing.Optional[bytes]]] = None,
processor: typing.Optional[typing.Callable[[asyncio.StreamReader, asyncio.StreamWriter], None]] = None,
) -> None:
self.host = host
self.port = port
self._server = None
self._response = response
self._callback = callback
self._processor = processor
self.data = b''
async def _handle(self, reader, writer) -> None:
if self._processor is not None:
self._processor(reader, writer)
return
while True:
data = await reader.read(2048)
if not data:
break
if self._callback:
self._callback(data)
resp = self._callback(data) if self._callback else self._response
if self._response is not None:
if resp is not None:
data = self._response
writer.write(data)
await writer.drain()

View File

@ -30,22 +30,23 @@ Author: Adolfo Gómez, dkmaster at dkmon dot com
'''
import asyncio
import contextlib
import os
import json
import logging
import multiprocessing
import os
import random
import socket
import ssl
import os
import typing
import string
import tempfile
import typing
from unittest import mock
import multiprocessing
import udstunnel
from uds_tunnel import consts, tunnel, stats, config
from uds_tunnel import config, consts, stats, tunnel
from . import certs, conf, fixtures, tools
logger = logging.getLogger(__name__)
if typing.TYPE_CHECKING:
@ -100,7 +101,7 @@ async def create_tunnel_proc(
remote_host: str,
remote_port: int,
*,
response: typing.Optional[typing.Mapping[str, typing.Any]] = None
response: typing.Optional[typing.Mapping[str, typing.Any]] = None,
) -> typing.AsyncGenerator['config.ConfigurationType', None]:
with create_config_file(listen_host, listen_port) as cfgfile:
args = mock.MagicMock()
@ -116,7 +117,7 @@ async def create_tunnel_proc(
response = conf.UDS_GET_TICKET_RESPONSE(remote_host, remote_port)
with mock.patch(
'uds_tunnel.tunnel.TunnelProtocol._readFromUDS',
'uds_tunnel.tunnel.TunnelProtocol._read_from_uds',
new_callable=tools.AsyncMock,
) as m:
m.return_value = response
@ -214,21 +215,26 @@ async def create_tunnel_server(
@contextlib.asynccontextmanager
async def create_test_tunnel(
*, callback: typing.Callable[[bytes], None]
*,
callback: typing.Callable[[bytes], None],
port: typing.Optional[int] = None,
remote_port: typing.Optional[int] = None,
) -> typing.AsyncGenerator['config.ConfigurationType', None]:
# Generate a listening server for testing tunnel
# Prepare the end of the tunnel
async with tools.AsyncTCPServer(port=54876, callback=callback) as server:
async with tools.AsyncTCPServer(
port=remote_port or 54876, callback=callback
) as server:
# Create a tunnel to localhost 13579
# SSl cert for tunnel server
with certs.ssl_context(server.host) as (ssl_ctx, _):
_, cfg = fixtures.get_config(
address=server.host,
port=7777,
port=port or 7777,
ipv6=':' in server.host,
)
with mock.patch(
'uds_tunnel.tunnel.TunnelProtocol._readFromUDS',
'uds_tunnel.tunnel.TunnelProtocol._read_from_uds',
new_callable=tools.AsyncMock,
) as m:
m.return_value = conf.UDS_GET_TICKET_RESPONSE(server.host, server.port)
@ -241,6 +247,26 @@ async def create_test_tunnel(
await tunnel_server.wait_closed()
@contextlib.asynccontextmanager
async def create_fake_broker_server(
response: typing.Mapping[str, typing.Any], port: int = 44443
) -> typing.AsyncGenerator[None, None]:
# crate a fake broker server
# Ignores request, and sends response
resp: bytes = b'HTTP/1.1 200 OK\r\n\r\n' + json.dumps(response).encode()
def callback(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
writer.write(resp)
writer.close()
async with tools.AsyncTCPServer(port=port, response=resp) as server:
try:
yield
finally:
pass # nothing to do
@contextlib.asynccontextmanager
async def open_tunnel_client(
cfg: 'config.ConfigurationType',
@ -284,7 +310,7 @@ async def tunnel_app_runner(
host: typing.Optional[str] = None,
port: typing.Optional[int] = None,
*,
args: typing.Optional[typing.List[str]] = None
args: typing.Optional[typing.List[str]] = None,
) -> typing.AsyncGenerator['Process', None]:
# Ensure we are on src directory
if os.path.basename(os.getcwd()) != 'src':
@ -312,3 +338,10 @@ async def tunnel_app_runner(
if process.returncode is None:
process.terminate()
await process.wait()
def get_correct_ticket(length: int = consts.TICKET_LENGTH) -> bytes:
return ''.join(
random.choice(string.ascii_letters + string.digits)
for _ in range(length)
).encode()