1
0
mirror of https://github.com/dkmstr/openuds.git synced 2025-01-11 05:17:55 +03:00

Small fixes and more tests. Refactoring.

This commit is contained in:
Adolfo Gómez García 2022-12-20 18:31:04 +01:00
parent fdd8ac00c9
commit 472299f878
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
10 changed files with 171 additions and 58 deletions

View File

@ -0,0 +1,2 @@
htmlcov
.coverage

View File

@ -1,5 +1,8 @@
[run] [run]
dynamic_context = test_function dynamic_context = test_function
omit =
test/*
/etc/*
[report] [report]
skip_empty = True skip_empty = True

View File

@ -1,5 +1,6 @@
[pytest] [pytest]
addopts = "-s" #addopts = "-s"
addopts = --cov --cov-report html --cov-config=coverage.ini -n 12
pythonpath = ./src pythonpath = ./src
python_files = tests.py test_*.py *_tests.py python_files = tests.py test_*.py *_tests.py
log_format = "%(asctime)s %(levelname)s %(message)s" log_format = "%(asctime)s %(levelname)s %(message)s"

View File

@ -114,7 +114,7 @@ class TunnelProtocol(asyncio.Protocol):
async def open_other_side() -> None: async def open_other_side() -> None:
try: try:
result = await TunnelProtocol.getTicketFromUDS( result = await TunnelProtocol.get_ticket_from_uds(
self.owner.cfg, ticket, self.source self.owner.cfg, ticket, self.source
) )
except Exception as e: except Exception as e:
@ -277,10 +277,10 @@ class TunnelProtocol(asyncio.Protocol):
logger.debug('Data received: %s', len(data)) logger.debug('Data received: %s', len(data))
self.runner(data) # send data to current runner (command or proxy) self.runner(data) # send data to current runner (command or proxy)
def notifyEnd(self): def notify_end(self):
if self.notify_ticket: if self.notify_ticket:
asyncio.get_event_loop().create_task( asyncio.get_event_loop().create_task(
TunnelProtocol.notifyEndToUds( TunnelProtocol.notify_end_to_uds(
self.owner.cfg, self.notify_ticket, self.stats_manager self.owner.cfg, self.notify_ticket, self.stats_manager
) )
) )
@ -295,7 +295,7 @@ class TunnelProtocol(asyncio.Protocol):
self.other_side.transport.close() self.other_side.transport.close()
else: else:
self.stats_manager.close() self.stats_manager.close()
self.notifyEnd() self.notify_end()
# helpers # helpers
@staticmethod @staticmethod
@ -324,12 +324,12 @@ class TunnelProtocol(asyncio.Protocol):
int(self.stats_manager.end - self.stats_manager.start), int(self.stats_manager.end - self.stats_manager.start),
) )
# Notify end to uds # Notify end to uds
self.notifyEnd() self.notify_end()
else: else:
logger.info('TERMINATED %s', self.pretty_source()) logger.info('TERMINATED %s', self.pretty_source())
@staticmethod @staticmethod
async def _readFromUDS( async def _read_from_uds(
cfg: config.ConfigurationType, cfg: config.ConfigurationType,
ticket: bytes, ticket: bytes,
msg: str, msg: str,
@ -358,7 +358,7 @@ class TunnelProtocol(asyncio.Protocol):
raise Exception(f'TICKET COMMS ERROR: {ticket.decode()} {msg} {e!s}') raise Exception(f'TICKET COMMS ERROR: {ticket.decode()} {msg} {e!s}')
@staticmethod @staticmethod
async def getTicketFromUDS( async def get_ticket_from_uds(
cfg: config.ConfigurationType, ticket: bytes, address: typing.Tuple[str, int] cfg: config.ConfigurationType, ticket: bytes, address: typing.Tuple[str, int]
) -> typing.MutableMapping[str, typing.Any]: ) -> typing.MutableMapping[str, typing.Any]:
# Sanity checks # Sanity checks
@ -374,13 +374,13 @@ class TunnelProtocol(asyncio.Protocol):
continue # Correctus continue # Correctus
raise ValueError(f'TICKET INVALID (char {i} at pos {n})') raise ValueError(f'TICKET INVALID (char {i} at pos {n})')
return await TunnelProtocol._readFromUDS(cfg, ticket, address[0]) return await TunnelProtocol._read_from_uds(cfg, ticket, address[0])
@staticmethod @staticmethod
async def notifyEndToUds( async def notify_end_to_uds(
cfg: config.ConfigurationType, ticket: bytes, counter: stats.Stats cfg: config.ConfigurationType, ticket: bytes, counter: stats.Stats
) -> None: ) -> None:
await TunnelProtocol._readFromUDS( await TunnelProtocol._read_from_uds(
cfg, cfg,
ticket, ticket,
'stop', 'stop',

View File

@ -30,7 +30,8 @@ Author: Adolfo Gómez, dkmaster at dkmon dot com
''' '''
import typing import typing
import asyncio import asyncio
import io import random
import string
import logging import logging
from unittest import IsolatedAsyncioTestCase, mock from unittest import IsolatedAsyncioTestCase, mock
@ -38,8 +39,12 @@ from uds_tunnel import consts
from .utils import tuntools, tools from .utils import tuntools, tools
if typing.TYPE_CHECKING:
from uds_tunnel import config
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class TestUDSTunnelApp(IsolatedAsyncioTestCase): class TestUDSTunnelApp(IsolatedAsyncioTestCase):
async def test_run_app_help(self) -> None: async def test_run_app_help(self) -> None:
# Executes the app with --help # Executes the app with --help
@ -50,9 +55,48 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
self.assertEqual(stderr, b'') self.assertEqual(stderr, b'')
self.assertIn(b'usage: udstunnel', stdout) self.assertIn(b'usage: udstunnel', stdout)
async def client_task(
self, cfg: 'config.ConfigurationType', host: str, port: int
) -> None:
received: bytes = b''
callback_invoked: asyncio.Event = asyncio.Event()
def callback(data: bytes) -> None:
nonlocal received
received += data
# if data contains EOS marcker ('STREAM_END'), we are done
if b'STREAM_END' in data:
callback_invoked.set()
async with tools.AsyncTCPServer(
host=host, port=5445, callback=callback
) as server:
# Create a random ticket with valid format
ticket = ''.join(
random.choice(string.ascii_letters + string.digits)
for _ in range(consts.TICKET_LENGTH)
).encode()
# Open and send handshake
async with tuntools.open_tunnel_client(cfg, use_tunnel_handshake=True) as (
creader,
cwriter,
):
# Now open command with ticket
cwriter.write(consts.COMMAND_OPEN)
# fake ticket, consts.TICKET_LENGTH bytes long, letters and numbers. Use a random ticket,
cwriter.write(ticket)
await cwriter.drain()
# Read response, should be ok
data = await creader.read(1024)
self.assertEqual(
data,
consts.RESPONSE_OK,
f'Server host: {host}:{port} - Ticket: {ticket!r} - Response: {data!r}',
)
async def test_run_app_serve(self) -> None: async def test_run_app_serve(self) -> None:
for host in ('127.0.0.1', '::1'): for host in ('127.0.0.1', '::1'):
async with tuntools.tunnel_app_runner(host, 7777) as process: async with tuntools.tunnel_app_runner(host, 7777) as process:
print(process) # Create a "bunch" of servers and clients
pass

View File

@ -66,7 +66,7 @@ class TestTunnel(IsolatedAsyncioTestCase):
# Set timeout to 1 seconds # Set timeout to 1 seconds
bad_cmd = bytes(random.randint(0, 255) for _ in range(i)) # Some garbage bad_cmd = bytes(random.randint(0, 255) for _ in range(i)) # Some garbage
logger.info(f'Testing invalid command with {bad_cmd!r}') logger.info(f'Testing invalid command with {bad_cmd!r}')
async with tuntools.create_test_tunnel(callback=lambda x: None) as cfg: async with tuntools.create_test_tunnel(callback=lambda x: None, port=7770, remote_port=54555) as cfg:
logger_mock = mock.MagicMock() logger_mock = mock.MagicMock()
with mock.patch('uds_tunnel.tunnel.logger', logger_mock): with mock.patch('uds_tunnel.tunnel.logger', logger_mock):
# Open connection to tunnel # Open connection to tunnel

View File

@ -65,7 +65,7 @@ class TestTunnelHelpers(IsolatedAsyncioTestCase):
# Test some invalid tickets # Test some invalid tickets
# Valid ticket are consts.TICKET_LENGTH bytes long, and must be A-Z, a-z, 0-9 # Valid ticket are consts.TICKET_LENGTH bytes long, and must be A-Z, a-z, 0-9
with mock.patch( with mock.patch(
'uds_tunnel.tunnel.TunnelProtocol._readFromUDS', 'uds_tunnel.tunnel.TunnelProtocol._read_from_uds',
new_callable=tools.AsyncMock, new_callable=tools.AsyncMock,
) as m: ) as m:
m.side_effect = uds_response m.side_effect = uds_response
@ -77,7 +77,7 @@ class TestTunnelHelpers(IsolatedAsyncioTestCase):
) )
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
await tunnel.TunnelProtocol.getTicketFromUDS( await tunnel.TunnelProtocol.get_ticket_from_uds(
cfg, ticket.encode(), conf.CALLER_HOST cfg, ticket.encode(), conf.CALLER_HOST
) )
@ -85,7 +85,7 @@ class TestTunnelHelpers(IsolatedAsyncioTestCase):
for i in range(0, 100): for i in range(0, 100):
# Now some requests with valid tickets # Now some requests with valid tickets
# Ensure no exception is raised # Ensure no exception is raised
ret_value = await tunnel.TunnelProtocol.getTicketFromUDS( ret_value = await tunnel.TunnelProtocol.get_ticket_from_uds(
cfg, ticket.encode(), conf.CALLER_HOST cfg, ticket.encode(), conf.CALLER_HOST
) )
# Ensure data returned is correct {host, port, notify} from mock # Ensure data returned is correct {host, port, notify} from mock
@ -108,7 +108,7 @@ class TestTunnelHelpers(IsolatedAsyncioTestCase):
async def test_notify_end_to_uds_broker(self) -> None: async def test_notify_end_to_uds_broker(self) -> None:
_, cfg = fixtures.get_config() _, cfg = fixtures.get_config()
with mock.patch( with mock.patch(
'uds_tunnel.tunnel.TunnelProtocol._readFromUDS', 'uds_tunnel.tunnel.TunnelProtocol._read_from_uds',
new_callable=tools.AsyncMock, new_callable=tools.AsyncMock,
) as m: ) as m:
m.side_effect = uds_response m.side_effect = uds_response
@ -118,7 +118,7 @@ class TestTunnelHelpers(IsolatedAsyncioTestCase):
ticket = conf.NOTIFY_TICKET.encode() ticket = conf.NOTIFY_TICKET.encode()
for i in range(0, 100): for i in range(0, 100):
await tunnel.TunnelProtocol.notifyEndToUds(cfg, ticket, counter) await tunnel.TunnelProtocol.notify_end_to_uds(cfg, ticket, counter)
self.assertEqual(m.call_args[0][0], cfg) self.assertEqual(m.call_args[0][0], cfg)
self.assertEqual( self.assertEqual(
@ -159,9 +159,9 @@ class TestTunnelHelpers(IsolatedAsyncioTestCase):
await tools.get(fake_uds_server), await tools.get(fake_uds_server),
'{"result":"ok"}', '{"result":"ok"}',
) )
# Now, tests _readFromUDS # Now, tests _read_from_uds
for i in range(100): for i in range(100):
ret = await tunnel.TunnelProtocol._readFromUDS( ret = await tunnel.TunnelProtocol._read_from_uds(
cfg, conf.NOTIFY_TICKET.encode(), 'test', {'param': 'value'} cfg, conf.NOTIFY_TICKET.encode(), 'test', {'param': 'value'}
) )
self.assertEqual(ret, {'result': 'ok'}) self.assertEqual(ret, {'result': 'ok'})

View File

@ -55,9 +55,9 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
# Remote is not really important in this tests, will fail before using it # Remote is not really important in this tests, will fail before using it
async with tuntools.create_tunnel_proc( async with tuntools.create_tunnel_proc(
host, host,
7777, 7890, # A port not used by any other test
'127.0.0.1', '127.0.0.1',
12345, 13579, # A port not used by any other test
) as cfg: ) as cfg:
for i in range(0, 8192, 128): for i in range(0, 8192, 128):
# Set timeout to 1 seconds # Set timeout to 1 seconds
@ -85,9 +85,9 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
# Remote is not really important in this tests, will return ok before using it (this is a TEST command, not OPEN) # Remote is not really important in this tests, will return ok before using it (this is a TEST command, not OPEN)
async with tuntools.create_tunnel_proc( async with tuntools.create_tunnel_proc(
host, host,
7777, 7891,
'127.0.0.1', '127.0.0.1',
12345, 13581,
) as cfg: ) as cfg:
for i in range(10): # Several times for i in range(10): # Several times
# On full, we need the handshake to be done, before connecting # On full, we need the handshake to be done, before connecting
@ -117,10 +117,7 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
0, consts.TICKET_LENGTH - 1, 4 0, consts.TICKET_LENGTH - 1, 4
): # All will fail. Any longer will be processed, and mock will return correct don't matter the ticket ): # All will fail. Any longer will be processed, and mock will return correct don't matter the ticket
# Ticket must contain only letters and numbers # Ticket must contain only letters and numbers
ticket = ''.join( ticket = tuntools.get_correct_ticket(i)
random.choice(string.ascii_letters + string.digits)
for _ in range(i)
).encode()
# On full, we need the handshake to be done, before connecting # On full, we need the handshake to be done, before connecting
# Our "test" server will simple "eat" the handshake, but we need to do it # Our "test" server will simple "eat" the handshake, but we need to do it
async with tuntools.open_tunnel_client( async with tuntools.open_tunnel_client(
@ -161,10 +158,54 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
) as cfg: ) as cfg:
for i in range(1): for i in range(1):
# Create a random ticket with valid format # Create a random ticket with valid format
ticket = ''.join( ticket = tuntools.get_correct_ticket()
random.choice(string.ascii_letters + string.digits) # On full, we need the handshake to be done, before connecting
for _ in range(consts.TICKET_LENGTH) # Our "test" server will simple "eat" the handshake, but we need to do it
).encode() async with tuntools.open_tunnel_client(
cfg, use_tunnel_handshake=True
) as (creader, cwriter):
cwriter.write(consts.COMMAND_OPEN)
# fake ticket, consts.TICKET_LENGTH bytes long, letters and numbers. Use a random ticket,
cwriter.write(ticket)
await cwriter.drain()
# Read response, should be ok
data = await creader.read(1024)
self.assertEqual(
data,
consts.RESPONSE_OK,
f'Tunnel host: {tunnel_host}, server host: {host}',
)
# Data sent will be received by server
# One single write will ensure all data is on same packet
test_str = (
b'Some Random Data'
+ bytes(random.randint(0, 255) for _ in range(8192))
+ b'STREAM_END'
)
# Clean received data
received = b''
# And reset event
callback_invoked.clear()
cwriter.write(test_str)
await cwriter.drain()
# Wait for callback to be invoked
await callback_invoked.wait()
self.assertEqual(received, test_str)
async def test_tunnel_no_remote(self) -> None:
for host in ('127.0.0.1', '::1'):
for tunnel_host in ('127.0.0.1', '::1'):
async with tuntools.create_tunnel_proc(
tunnel_host,
7888,
host,
17222, # Any non used port will do the trick
) as cfg:
ticket = tuntools.get_correct_ticket()
# On full, we need the handshake to be done, before connecting # On full, we need the handshake to be done, before connecting
# Our "test" server will simple "eat" the handshake, but we need to do it # Our "test" server will simple "eat" the handshake, but we need to do it
async with tuntools.open_tunnel_client( async with tuntools.open_tunnel_client(
@ -177,19 +218,4 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
await cwriter.drain() await cwriter.drain()
# Read response # Read response
data = await creader.read(1024) data = await creader.read(1024)
self.assertEqual(data, consts.RESPONSE_OK, f'Tunnel host: {tunnel_host}, server host: {host}') self.assertEqual(data, b'', f'Tunnel host: {tunnel_host}, server host: {host}')
# Data sent will be received by server
# One single write will ensure all data is on same packet
test_str = b'Some Random Data' + bytes(random.randint(0, 255) for _ in range(8192)) + b'STREAM_END'
# Clean received data
received = b''
# And reset event
callback_invoked.clear()
cwriter.write(test_str)
await cwriter.drain()
# Wait for callback to be invoked
await callback_invoked.wait()
self.assertEqual(received, test_str)

View File

@ -38,7 +38,6 @@ from unittest import mock
from . import certs from . import certs
class AsyncMock(mock.MagicMock): class AsyncMock(mock.MagicMock):
async def __call__(self, *args, **kwargs): async def __call__(self, *args, **kwargs):
return super().__call__(*args, **kwargs) return super().__call__(*args, **kwargs)
@ -111,7 +110,8 @@ class AsyncTCPServer:
port: int port: int
_server: typing.Optional[asyncio.AbstractServer] _server: typing.Optional[asyncio.AbstractServer]
_response: typing.Optional[bytes] _response: typing.Optional[bytes]
_callback: typing.Optional[typing.Callable[[bytes], None]] _callback: typing.Optional[typing.Callable[[bytes], typing.Optional[bytes]]]
_processor: typing.Optional[typing.Callable[[asyncio.StreamReader, asyncio.StreamWriter], None]]
def __init__( def __init__(
self, self,
@ -119,26 +119,30 @@ class AsyncTCPServer:
*, *,
response: typing.Optional[bytes] = None, response: typing.Optional[bytes] = None,
host: str = '127.0.0.1', # ip host: str = '127.0.0.1', # ip
callback: typing.Optional[typing.Callable[[bytes], None]] = None, callback: typing.Optional[typing.Callable[[bytes], typing.Optional[bytes]]] = None,
processor: typing.Optional[typing.Callable[[asyncio.StreamReader, asyncio.StreamWriter], None]] = None,
) -> None: ) -> None:
self.host = host self.host = host
self.port = port self.port = port
self._server = None self._server = None
self._response = response self._response = response
self._callback = callback self._callback = callback
self._processor = processor
self.data = b'' self.data = b''
async def _handle(self, reader, writer) -> None: async def _handle(self, reader, writer) -> None:
if self._processor is not None:
self._processor(reader, writer)
return
while True: while True:
data = await reader.read(2048) data = await reader.read(2048)
if not data: if not data:
break break
if self._callback: resp = self._callback(data) if self._callback else self._response
self._callback(data)
if self._response is not None: if resp is not None:
data = self._response data = self._response
writer.write(data) writer.write(data)
await writer.drain() await writer.drain()

View File

@ -30,22 +30,23 @@ Author: Adolfo Gómez, dkmaster at dkmon dot com
''' '''
import asyncio import asyncio
import contextlib import contextlib
import os import json
import logging import logging
import multiprocessing
import os
import random
import socket import socket
import ssl import ssl
import os import string
import typing
import tempfile import tempfile
import typing
from unittest import mock from unittest import mock
import multiprocessing
import udstunnel import udstunnel
from uds_tunnel import consts, tunnel, stats, config from uds_tunnel import config, consts, stats, tunnel
from . import certs, conf, fixtures, tools from . import certs, conf, fixtures, tools
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
@ -100,7 +101,7 @@ async def create_tunnel_proc(
remote_host: str, remote_host: str,
remote_port: int, remote_port: int,
*, *,
response: typing.Optional[typing.Mapping[str, typing.Any]] = None response: typing.Optional[typing.Mapping[str, typing.Any]] = None,
) -> typing.AsyncGenerator['config.ConfigurationType', None]: ) -> typing.AsyncGenerator['config.ConfigurationType', None]:
with create_config_file(listen_host, listen_port) as cfgfile: with create_config_file(listen_host, listen_port) as cfgfile:
args = mock.MagicMock() args = mock.MagicMock()
@ -116,7 +117,7 @@ async def create_tunnel_proc(
response = conf.UDS_GET_TICKET_RESPONSE(remote_host, remote_port) response = conf.UDS_GET_TICKET_RESPONSE(remote_host, remote_port)
with mock.patch( with mock.patch(
'uds_tunnel.tunnel.TunnelProtocol._readFromUDS', 'uds_tunnel.tunnel.TunnelProtocol._read_from_uds',
new_callable=tools.AsyncMock, new_callable=tools.AsyncMock,
) as m: ) as m:
m.return_value = response m.return_value = response
@ -214,21 +215,26 @@ async def create_tunnel_server(
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def create_test_tunnel( async def create_test_tunnel(
*, callback: typing.Callable[[bytes], None] *,
callback: typing.Callable[[bytes], None],
port: typing.Optional[int] = None,
remote_port: typing.Optional[int] = None,
) -> typing.AsyncGenerator['config.ConfigurationType', None]: ) -> typing.AsyncGenerator['config.ConfigurationType', None]:
# Generate a listening server for testing tunnel # Generate a listening server for testing tunnel
# Prepare the end of the tunnel # Prepare the end of the tunnel
async with tools.AsyncTCPServer(port=54876, callback=callback) as server: async with tools.AsyncTCPServer(
port=remote_port or 54876, callback=callback
) as server:
# Create a tunnel to localhost 13579 # Create a tunnel to localhost 13579
# SSl cert for tunnel server # SSl cert for tunnel server
with certs.ssl_context(server.host) as (ssl_ctx, _): with certs.ssl_context(server.host) as (ssl_ctx, _):
_, cfg = fixtures.get_config( _, cfg = fixtures.get_config(
address=server.host, address=server.host,
port=7777, port=port or 7777,
ipv6=':' in server.host, ipv6=':' in server.host,
) )
with mock.patch( with mock.patch(
'uds_tunnel.tunnel.TunnelProtocol._readFromUDS', 'uds_tunnel.tunnel.TunnelProtocol._read_from_uds',
new_callable=tools.AsyncMock, new_callable=tools.AsyncMock,
) as m: ) as m:
m.return_value = conf.UDS_GET_TICKET_RESPONSE(server.host, server.port) m.return_value = conf.UDS_GET_TICKET_RESPONSE(server.host, server.port)
@ -241,6 +247,26 @@ async def create_test_tunnel(
await tunnel_server.wait_closed() await tunnel_server.wait_closed()
@contextlib.asynccontextmanager
async def create_fake_broker_server(
response: typing.Mapping[str, typing.Any], port: int = 44443
) -> typing.AsyncGenerator[None, None]:
# crate a fake broker server
# Ignores request, and sends response
resp: bytes = b'HTTP/1.1 200 OK\r\n\r\n' + json.dumps(response).encode()
def callback(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
writer.write(resp)
writer.close()
async with tools.AsyncTCPServer(port=port, response=resp) as server:
try:
yield
finally:
pass # nothing to do
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def open_tunnel_client( async def open_tunnel_client(
cfg: 'config.ConfigurationType', cfg: 'config.ConfigurationType',
@ -284,7 +310,7 @@ async def tunnel_app_runner(
host: typing.Optional[str] = None, host: typing.Optional[str] = None,
port: typing.Optional[int] = None, port: typing.Optional[int] = None,
*, *,
args: typing.Optional[typing.List[str]] = None args: typing.Optional[typing.List[str]] = None,
) -> typing.AsyncGenerator['Process', None]: ) -> typing.AsyncGenerator['Process', None]:
# Ensure we are on src directory # Ensure we are on src directory
if os.path.basename(os.getcwd()) != 'src': if os.path.basename(os.getcwd()) != 'src':
@ -312,3 +338,10 @@ async def tunnel_app_runner(
if process.returncode is None: if process.returncode is None:
process.terminate() process.terminate()
await process.wait() await process.wait()
def get_correct_ticket(length: int = consts.TICKET_LENGTH) -> bytes:
return ''.join(
random.choice(string.ascii_letters + string.digits)
for _ in range(length)
).encode()