mirror of
https://github.com/dkmstr/openuds.git
synced 2024-12-22 13:34:04 +03:00
Small fixes and more tests. Refactoring.
This commit is contained in:
parent
fdd8ac00c9
commit
472299f878
2
tunnel-server/.gitignore
vendored
2
tunnel-server/.gitignore
vendored
@ -0,0 +1,2 @@
|
||||
htmlcov
|
||||
.coverage
|
@ -1,5 +1,8 @@
|
||||
[run]
|
||||
dynamic_context = test_function
|
||||
omit =
|
||||
test/*
|
||||
/etc/*
|
||||
|
||||
[report]
|
||||
skip_empty = True
|
||||
|
@ -1,5 +1,6 @@
|
||||
[pytest]
|
||||
addopts = "-s"
|
||||
#addopts = "-s"
|
||||
addopts = --cov --cov-report html --cov-config=coverage.ini -n 12
|
||||
pythonpath = ./src
|
||||
python_files = tests.py test_*.py *_tests.py
|
||||
log_format = "%(asctime)s %(levelname)s %(message)s"
|
||||
|
@ -114,7 +114,7 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
|
||||
async def open_other_side() -> None:
|
||||
try:
|
||||
result = await TunnelProtocol.getTicketFromUDS(
|
||||
result = await TunnelProtocol.get_ticket_from_uds(
|
||||
self.owner.cfg, ticket, self.source
|
||||
)
|
||||
except Exception as e:
|
||||
@ -277,10 +277,10 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
logger.debug('Data received: %s', len(data))
|
||||
self.runner(data) # send data to current runner (command or proxy)
|
||||
|
||||
def notifyEnd(self):
|
||||
def notify_end(self):
|
||||
if self.notify_ticket:
|
||||
asyncio.get_event_loop().create_task(
|
||||
TunnelProtocol.notifyEndToUds(
|
||||
TunnelProtocol.notify_end_to_uds(
|
||||
self.owner.cfg, self.notify_ticket, self.stats_manager
|
||||
)
|
||||
)
|
||||
@ -295,7 +295,7 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
self.other_side.transport.close()
|
||||
else:
|
||||
self.stats_manager.close()
|
||||
self.notifyEnd()
|
||||
self.notify_end()
|
||||
|
||||
# helpers
|
||||
@staticmethod
|
||||
@ -324,12 +324,12 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
int(self.stats_manager.end - self.stats_manager.start),
|
||||
)
|
||||
# Notify end to uds
|
||||
self.notifyEnd()
|
||||
self.notify_end()
|
||||
else:
|
||||
logger.info('TERMINATED %s', self.pretty_source())
|
||||
|
||||
@staticmethod
|
||||
async def _readFromUDS(
|
||||
async def _read_from_uds(
|
||||
cfg: config.ConfigurationType,
|
||||
ticket: bytes,
|
||||
msg: str,
|
||||
@ -358,7 +358,7 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
raise Exception(f'TICKET COMMS ERROR: {ticket.decode()} {msg} {e!s}')
|
||||
|
||||
@staticmethod
|
||||
async def getTicketFromUDS(
|
||||
async def get_ticket_from_uds(
|
||||
cfg: config.ConfigurationType, ticket: bytes, address: typing.Tuple[str, int]
|
||||
) -> typing.MutableMapping[str, typing.Any]:
|
||||
# Sanity checks
|
||||
@ -374,13 +374,13 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
continue # Correctus
|
||||
raise ValueError(f'TICKET INVALID (char {i} at pos {n})')
|
||||
|
||||
return await TunnelProtocol._readFromUDS(cfg, ticket, address[0])
|
||||
return await TunnelProtocol._read_from_uds(cfg, ticket, address[0])
|
||||
|
||||
@staticmethod
|
||||
async def notifyEndToUds(
|
||||
async def notify_end_to_uds(
|
||||
cfg: config.ConfigurationType, ticket: bytes, counter: stats.Stats
|
||||
) -> None:
|
||||
await TunnelProtocol._readFromUDS(
|
||||
await TunnelProtocol._read_from_uds(
|
||||
cfg,
|
||||
ticket,
|
||||
'stop',
|
||||
|
@ -30,7 +30,8 @@ Author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
'''
|
||||
import typing
|
||||
import asyncio
|
||||
import io
|
||||
import random
|
||||
import string
|
||||
import logging
|
||||
from unittest import IsolatedAsyncioTestCase, mock
|
||||
|
||||
@ -38,8 +39,12 @@ from uds_tunnel import consts
|
||||
|
||||
from .utils import tuntools, tools
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from uds_tunnel import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TestUDSTunnelApp(IsolatedAsyncioTestCase):
|
||||
async def test_run_app_help(self) -> None:
|
||||
# Executes the app with --help
|
||||
@ -49,10 +54,49 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
|
||||
self.assertEqual(process.returncode, 0, f'{stdout!r} {stderr!r}')
|
||||
self.assertEqual(stderr, b'')
|
||||
self.assertIn(b'usage: udstunnel', stdout)
|
||||
|
||||
|
||||
async def client_task(
|
||||
self, cfg: 'config.ConfigurationType', host: str, port: int
|
||||
) -> None:
|
||||
received: bytes = b''
|
||||
callback_invoked: asyncio.Event = asyncio.Event()
|
||||
|
||||
def callback(data: bytes) -> None:
|
||||
nonlocal received
|
||||
received += data
|
||||
# if data contains EOS marcker ('STREAM_END'), we are done
|
||||
if b'STREAM_END' in data:
|
||||
callback_invoked.set()
|
||||
|
||||
async with tools.AsyncTCPServer(
|
||||
host=host, port=5445, callback=callback
|
||||
) as server:
|
||||
# Create a random ticket with valid format
|
||||
ticket = ''.join(
|
||||
random.choice(string.ascii_letters + string.digits)
|
||||
for _ in range(consts.TICKET_LENGTH)
|
||||
).encode()
|
||||
# Open and send handshake
|
||||
async with tuntools.open_tunnel_client(cfg, use_tunnel_handshake=True) as (
|
||||
creader,
|
||||
cwriter,
|
||||
):
|
||||
# Now open command with ticket
|
||||
cwriter.write(consts.COMMAND_OPEN)
|
||||
# fake ticket, consts.TICKET_LENGTH bytes long, letters and numbers. Use a random ticket,
|
||||
cwriter.write(ticket)
|
||||
|
||||
await cwriter.drain()
|
||||
# Read response, should be ok
|
||||
data = await creader.read(1024)
|
||||
self.assertEqual(
|
||||
data,
|
||||
consts.RESPONSE_OK,
|
||||
f'Server host: {host}:{port} - Ticket: {ticket!r} - Response: {data!r}',
|
||||
)
|
||||
|
||||
async def test_run_app_serve(self) -> None:
|
||||
for host in ('127.0.0.1', '::1'):
|
||||
async with tuntools.tunnel_app_runner(host, 7777) as process:
|
||||
print(process)
|
||||
|
||||
# Create a "bunch" of servers and clients
|
||||
pass
|
||||
|
@ -66,7 +66,7 @@ class TestTunnel(IsolatedAsyncioTestCase):
|
||||
# Set timeout to 1 seconds
|
||||
bad_cmd = bytes(random.randint(0, 255) for _ in range(i)) # Some garbage
|
||||
logger.info(f'Testing invalid command with {bad_cmd!r}')
|
||||
async with tuntools.create_test_tunnel(callback=lambda x: None) as cfg:
|
||||
async with tuntools.create_test_tunnel(callback=lambda x: None, port=7770, remote_port=54555) as cfg:
|
||||
logger_mock = mock.MagicMock()
|
||||
with mock.patch('uds_tunnel.tunnel.logger', logger_mock):
|
||||
# Open connection to tunnel
|
||||
|
@ -65,7 +65,7 @@ class TestTunnelHelpers(IsolatedAsyncioTestCase):
|
||||
# Test some invalid tickets
|
||||
# Valid ticket are consts.TICKET_LENGTH bytes long, and must be A-Z, a-z, 0-9
|
||||
with mock.patch(
|
||||
'uds_tunnel.tunnel.TunnelProtocol._readFromUDS',
|
||||
'uds_tunnel.tunnel.TunnelProtocol._read_from_uds',
|
||||
new_callable=tools.AsyncMock,
|
||||
) as m:
|
||||
m.side_effect = uds_response
|
||||
@ -77,7 +77,7 @@ class TestTunnelHelpers(IsolatedAsyncioTestCase):
|
||||
)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
await tunnel.TunnelProtocol.getTicketFromUDS(
|
||||
await tunnel.TunnelProtocol.get_ticket_from_uds(
|
||||
cfg, ticket.encode(), conf.CALLER_HOST
|
||||
)
|
||||
|
||||
@ -85,7 +85,7 @@ class TestTunnelHelpers(IsolatedAsyncioTestCase):
|
||||
for i in range(0, 100):
|
||||
# Now some requests with valid tickets
|
||||
# Ensure no exception is raised
|
||||
ret_value = await tunnel.TunnelProtocol.getTicketFromUDS(
|
||||
ret_value = await tunnel.TunnelProtocol.get_ticket_from_uds(
|
||||
cfg, ticket.encode(), conf.CALLER_HOST
|
||||
)
|
||||
# Ensure data returned is correct {host, port, notify} from mock
|
||||
@ -108,7 +108,7 @@ class TestTunnelHelpers(IsolatedAsyncioTestCase):
|
||||
async def test_notify_end_to_uds_broker(self) -> None:
|
||||
_, cfg = fixtures.get_config()
|
||||
with mock.patch(
|
||||
'uds_tunnel.tunnel.TunnelProtocol._readFromUDS',
|
||||
'uds_tunnel.tunnel.TunnelProtocol._read_from_uds',
|
||||
new_callable=tools.AsyncMock,
|
||||
) as m:
|
||||
m.side_effect = uds_response
|
||||
@ -118,7 +118,7 @@ class TestTunnelHelpers(IsolatedAsyncioTestCase):
|
||||
|
||||
ticket = conf.NOTIFY_TICKET.encode()
|
||||
for i in range(0, 100):
|
||||
await tunnel.TunnelProtocol.notifyEndToUds(cfg, ticket, counter)
|
||||
await tunnel.TunnelProtocol.notify_end_to_uds(cfg, ticket, counter)
|
||||
|
||||
self.assertEqual(m.call_args[0][0], cfg)
|
||||
self.assertEqual(
|
||||
@ -159,9 +159,9 @@ class TestTunnelHelpers(IsolatedAsyncioTestCase):
|
||||
await tools.get(fake_uds_server),
|
||||
'{"result":"ok"}',
|
||||
)
|
||||
# Now, tests _readFromUDS
|
||||
# Now, tests _read_from_uds
|
||||
for i in range(100):
|
||||
ret = await tunnel.TunnelProtocol._readFromUDS(
|
||||
ret = await tunnel.TunnelProtocol._read_from_uds(
|
||||
cfg, conf.NOTIFY_TICKET.encode(), 'test', {'param': 'value'}
|
||||
)
|
||||
self.assertEqual(ret, {'result': 'ok'})
|
||||
|
@ -55,9 +55,9 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
|
||||
# Remote is not really important in this tests, will fail before using it
|
||||
async with tuntools.create_tunnel_proc(
|
||||
host,
|
||||
7777,
|
||||
7890, # A port not used by any other test
|
||||
'127.0.0.1',
|
||||
12345,
|
||||
13579, # A port not used by any other test
|
||||
) as cfg:
|
||||
for i in range(0, 8192, 128):
|
||||
# Set timeout to 1 seconds
|
||||
@ -85,9 +85,9 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
|
||||
# Remote is not really important in this tests, will return ok before using it (this is a TEST command, not OPEN)
|
||||
async with tuntools.create_tunnel_proc(
|
||||
host,
|
||||
7777,
|
||||
7891,
|
||||
'127.0.0.1',
|
||||
12345,
|
||||
13581,
|
||||
) as cfg:
|
||||
for i in range(10): # Several times
|
||||
# On full, we need the handshake to be done, before connecting
|
||||
@ -117,10 +117,7 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
|
||||
0, consts.TICKET_LENGTH - 1, 4
|
||||
): # All will fail. Any longer will be processed, and mock will return correct don't matter the ticket
|
||||
# Ticket must contain only letters and numbers
|
||||
ticket = ''.join(
|
||||
random.choice(string.ascii_letters + string.digits)
|
||||
for _ in range(i)
|
||||
).encode()
|
||||
ticket = tuntools.get_correct_ticket(i)
|
||||
# On full, we need the handshake to be done, before connecting
|
||||
# Our "test" server will simple "eat" the handshake, but we need to do it
|
||||
async with tuntools.open_tunnel_client(
|
||||
@ -161,10 +158,7 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
|
||||
) as cfg:
|
||||
for i in range(1):
|
||||
# Create a random ticket with valid format
|
||||
ticket = ''.join(
|
||||
random.choice(string.ascii_letters + string.digits)
|
||||
for _ in range(consts.TICKET_LENGTH)
|
||||
).encode()
|
||||
ticket = tuntools.get_correct_ticket()
|
||||
# On full, we need the handshake to be done, before connecting
|
||||
# Our "test" server will simple "eat" the handshake, but we need to do it
|
||||
async with tuntools.open_tunnel_client(
|
||||
@ -175,21 +169,53 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
|
||||
cwriter.write(ticket)
|
||||
|
||||
await cwriter.drain()
|
||||
# Read response
|
||||
# Read response, should be ok
|
||||
data = await creader.read(1024)
|
||||
self.assertEqual(data, consts.RESPONSE_OK, f'Tunnel host: {tunnel_host}, server host: {host}')
|
||||
self.assertEqual(
|
||||
data,
|
||||
consts.RESPONSE_OK,
|
||||
f'Tunnel host: {tunnel_host}, server host: {host}',
|
||||
)
|
||||
|
||||
# Data sent will be received by server
|
||||
# One single write will ensure all data is on same packet
|
||||
test_str = b'Some Random Data' + bytes(random.randint(0, 255) for _ in range(8192)) + b'STREAM_END'
|
||||
test_str = (
|
||||
b'Some Random Data'
|
||||
+ bytes(random.randint(0, 255) for _ in range(8192))
|
||||
+ b'STREAM_END'
|
||||
)
|
||||
# Clean received data
|
||||
received = b''
|
||||
# And reset event
|
||||
callback_invoked.clear()
|
||||
|
||||
|
||||
cwriter.write(test_str)
|
||||
await cwriter.drain()
|
||||
|
||||
# Wait for callback to be invoked
|
||||
await callback_invoked.wait()
|
||||
self.assertEqual(received, test_str)
|
||||
|
||||
async def test_tunnel_no_remote(self) -> None:
|
||||
for host in ('127.0.0.1', '::1'):
|
||||
for tunnel_host in ('127.0.0.1', '::1'):
|
||||
async with tuntools.create_tunnel_proc(
|
||||
tunnel_host,
|
||||
7888,
|
||||
host,
|
||||
17222, # Any non used port will do the trick
|
||||
) as cfg:
|
||||
ticket = tuntools.get_correct_ticket()
|
||||
# On full, we need the handshake to be done, before connecting
|
||||
# Our "test" server will simple "eat" the handshake, but we need to do it
|
||||
async with tuntools.open_tunnel_client(
|
||||
cfg, use_tunnel_handshake=True
|
||||
) as (creader, cwriter):
|
||||
cwriter.write(consts.COMMAND_OPEN)
|
||||
# fake ticket, consts.TICKET_LENGTH bytes long, letters and numbers. Use a random ticket,
|
||||
cwriter.write(ticket)
|
||||
|
||||
await cwriter.drain()
|
||||
# Read response
|
||||
data = await creader.read(1024)
|
||||
self.assertEqual(data, b'', f'Tunnel host: {tunnel_host}, server host: {host}')
|
||||
|
@ -38,7 +38,6 @@ from unittest import mock
|
||||
|
||||
from . import certs
|
||||
|
||||
|
||||
class AsyncMock(mock.MagicMock):
|
||||
async def __call__(self, *args, **kwargs):
|
||||
return super().__call__(*args, **kwargs)
|
||||
@ -111,7 +110,8 @@ class AsyncTCPServer:
|
||||
port: int
|
||||
_server: typing.Optional[asyncio.AbstractServer]
|
||||
_response: typing.Optional[bytes]
|
||||
_callback: typing.Optional[typing.Callable[[bytes], None]]
|
||||
_callback: typing.Optional[typing.Callable[[bytes], typing.Optional[bytes]]]
|
||||
_processor: typing.Optional[typing.Callable[[asyncio.StreamReader, asyncio.StreamWriter], None]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -119,26 +119,30 @@ class AsyncTCPServer:
|
||||
*,
|
||||
response: typing.Optional[bytes] = None,
|
||||
host: str = '127.0.0.1', # ip
|
||||
callback: typing.Optional[typing.Callable[[bytes], None]] = None,
|
||||
callback: typing.Optional[typing.Callable[[bytes], typing.Optional[bytes]]] = None,
|
||||
processor: typing.Optional[typing.Callable[[asyncio.StreamReader, asyncio.StreamWriter], None]] = None,
|
||||
) -> None:
|
||||
self.host = host
|
||||
self.port = port
|
||||
self._server = None
|
||||
self._response = response
|
||||
self._callback = callback
|
||||
self._processor = processor
|
||||
|
||||
self.data = b''
|
||||
|
||||
async def _handle(self, reader, writer) -> None:
|
||||
if self._processor is not None:
|
||||
self._processor(reader, writer)
|
||||
return
|
||||
while True:
|
||||
data = await reader.read(2048)
|
||||
if not data:
|
||||
break
|
||||
|
||||
if self._callback:
|
||||
self._callback(data)
|
||||
resp = self._callback(data) if self._callback else self._response
|
||||
|
||||
if self._response is not None:
|
||||
if resp is not None:
|
||||
data = self._response
|
||||
writer.write(data)
|
||||
await writer.drain()
|
||||
|
@ -30,22 +30,23 @@ Author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
'''
|
||||
import asyncio
|
||||
import contextlib
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import random
|
||||
import socket
|
||||
import ssl
|
||||
import os
|
||||
import typing
|
||||
import string
|
||||
import tempfile
|
||||
import typing
|
||||
from unittest import mock
|
||||
import multiprocessing
|
||||
|
||||
import udstunnel
|
||||
from uds_tunnel import consts, tunnel, stats, config
|
||||
from uds_tunnel import config, consts, stats, tunnel
|
||||
|
||||
from . import certs, conf, fixtures, tools
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
@ -100,7 +101,7 @@ async def create_tunnel_proc(
|
||||
remote_host: str,
|
||||
remote_port: int,
|
||||
*,
|
||||
response: typing.Optional[typing.Mapping[str, typing.Any]] = None
|
||||
response: typing.Optional[typing.Mapping[str, typing.Any]] = None,
|
||||
) -> typing.AsyncGenerator['config.ConfigurationType', None]:
|
||||
with create_config_file(listen_host, listen_port) as cfgfile:
|
||||
args = mock.MagicMock()
|
||||
@ -116,7 +117,7 @@ async def create_tunnel_proc(
|
||||
response = conf.UDS_GET_TICKET_RESPONSE(remote_host, remote_port)
|
||||
|
||||
with mock.patch(
|
||||
'uds_tunnel.tunnel.TunnelProtocol._readFromUDS',
|
||||
'uds_tunnel.tunnel.TunnelProtocol._read_from_uds',
|
||||
new_callable=tools.AsyncMock,
|
||||
) as m:
|
||||
m.return_value = response
|
||||
@ -214,21 +215,26 @@ async def create_tunnel_server(
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def create_test_tunnel(
|
||||
*, callback: typing.Callable[[bytes], None]
|
||||
*,
|
||||
callback: typing.Callable[[bytes], None],
|
||||
port: typing.Optional[int] = None,
|
||||
remote_port: typing.Optional[int] = None,
|
||||
) -> typing.AsyncGenerator['config.ConfigurationType', None]:
|
||||
# Generate a listening server for testing tunnel
|
||||
# Prepare the end of the tunnel
|
||||
async with tools.AsyncTCPServer(port=54876, callback=callback) as server:
|
||||
async with tools.AsyncTCPServer(
|
||||
port=remote_port or 54876, callback=callback
|
||||
) as server:
|
||||
# Create a tunnel to localhost 13579
|
||||
# SSl cert for tunnel server
|
||||
with certs.ssl_context(server.host) as (ssl_ctx, _):
|
||||
_, cfg = fixtures.get_config(
|
||||
address=server.host,
|
||||
port=7777,
|
||||
port=port or 7777,
|
||||
ipv6=':' in server.host,
|
||||
)
|
||||
with mock.patch(
|
||||
'uds_tunnel.tunnel.TunnelProtocol._readFromUDS',
|
||||
'uds_tunnel.tunnel.TunnelProtocol._read_from_uds',
|
||||
new_callable=tools.AsyncMock,
|
||||
) as m:
|
||||
m.return_value = conf.UDS_GET_TICKET_RESPONSE(server.host, server.port)
|
||||
@ -241,6 +247,26 @@ async def create_test_tunnel(
|
||||
await tunnel_server.wait_closed()
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def create_fake_broker_server(
|
||||
response: typing.Mapping[str, typing.Any], port: int = 44443
|
||||
) -> typing.AsyncGenerator[None, None]:
|
||||
# crate a fake broker server
|
||||
# Ignores request, and sends response
|
||||
|
||||
resp: bytes = b'HTTP/1.1 200 OK\r\n\r\n' + json.dumps(response).encode()
|
||||
|
||||
def callback(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
|
||||
writer.write(resp)
|
||||
writer.close()
|
||||
|
||||
async with tools.AsyncTCPServer(port=port, response=resp) as server:
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
pass # nothing to do
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def open_tunnel_client(
|
||||
cfg: 'config.ConfigurationType',
|
||||
@ -284,7 +310,7 @@ async def tunnel_app_runner(
|
||||
host: typing.Optional[str] = None,
|
||||
port: typing.Optional[int] = None,
|
||||
*,
|
||||
args: typing.Optional[typing.List[str]] = None
|
||||
args: typing.Optional[typing.List[str]] = None,
|
||||
) -> typing.AsyncGenerator['Process', None]:
|
||||
# Ensure we are on src directory
|
||||
if os.path.basename(os.getcwd()) != 'src':
|
||||
@ -312,3 +338,10 @@ async def tunnel_app_runner(
|
||||
if process.returncode is None:
|
||||
process.terminate()
|
||||
await process.wait()
|
||||
|
||||
|
||||
def get_correct_ticket(length: int = consts.TICKET_LENGTH) -> bytes:
|
||||
return ''.join(
|
||||
random.choice(string.ascii_letters + string.digits)
|
||||
for _ in range(length)
|
||||
).encode()
|
||||
|
Loading…
Reference in New Issue
Block a user