mirror of
https://github.com/dkmstr/openuds.git
synced 2024-12-22 13:34:04 +03:00
Small fixes and more tests. Refactoring.
This commit is contained in:
parent
fdd8ac00c9
commit
472299f878
2
tunnel-server/.gitignore
vendored
2
tunnel-server/.gitignore
vendored
@ -0,0 +1,2 @@
|
|||||||
|
htmlcov
|
||||||
|
.coverage
|
@ -1,5 +1,8 @@
|
|||||||
[run]
|
[run]
|
||||||
dynamic_context = test_function
|
dynamic_context = test_function
|
||||||
|
omit =
|
||||||
|
test/*
|
||||||
|
/etc/*
|
||||||
|
|
||||||
[report]
|
[report]
|
||||||
skip_empty = True
|
skip_empty = True
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
[pytest]
|
[pytest]
|
||||||
addopts = "-s"
|
#addopts = "-s"
|
||||||
|
addopts = --cov --cov-report html --cov-config=coverage.ini -n 12
|
||||||
pythonpath = ./src
|
pythonpath = ./src
|
||||||
python_files = tests.py test_*.py *_tests.py
|
python_files = tests.py test_*.py *_tests.py
|
||||||
log_format = "%(asctime)s %(levelname)s %(message)s"
|
log_format = "%(asctime)s %(levelname)s %(message)s"
|
||||||
|
@ -114,7 +114,7 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
|
|
||||||
async def open_other_side() -> None:
|
async def open_other_side() -> None:
|
||||||
try:
|
try:
|
||||||
result = await TunnelProtocol.getTicketFromUDS(
|
result = await TunnelProtocol.get_ticket_from_uds(
|
||||||
self.owner.cfg, ticket, self.source
|
self.owner.cfg, ticket, self.source
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -277,10 +277,10 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
logger.debug('Data received: %s', len(data))
|
logger.debug('Data received: %s', len(data))
|
||||||
self.runner(data) # send data to current runner (command or proxy)
|
self.runner(data) # send data to current runner (command or proxy)
|
||||||
|
|
||||||
def notifyEnd(self):
|
def notify_end(self):
|
||||||
if self.notify_ticket:
|
if self.notify_ticket:
|
||||||
asyncio.get_event_loop().create_task(
|
asyncio.get_event_loop().create_task(
|
||||||
TunnelProtocol.notifyEndToUds(
|
TunnelProtocol.notify_end_to_uds(
|
||||||
self.owner.cfg, self.notify_ticket, self.stats_manager
|
self.owner.cfg, self.notify_ticket, self.stats_manager
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -295,7 +295,7 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
self.other_side.transport.close()
|
self.other_side.transport.close()
|
||||||
else:
|
else:
|
||||||
self.stats_manager.close()
|
self.stats_manager.close()
|
||||||
self.notifyEnd()
|
self.notify_end()
|
||||||
|
|
||||||
# helpers
|
# helpers
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -324,12 +324,12 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
int(self.stats_manager.end - self.stats_manager.start),
|
int(self.stats_manager.end - self.stats_manager.start),
|
||||||
)
|
)
|
||||||
# Notify end to uds
|
# Notify end to uds
|
||||||
self.notifyEnd()
|
self.notify_end()
|
||||||
else:
|
else:
|
||||||
logger.info('TERMINATED %s', self.pretty_source())
|
logger.info('TERMINATED %s', self.pretty_source())
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _readFromUDS(
|
async def _read_from_uds(
|
||||||
cfg: config.ConfigurationType,
|
cfg: config.ConfigurationType,
|
||||||
ticket: bytes,
|
ticket: bytes,
|
||||||
msg: str,
|
msg: str,
|
||||||
@ -358,7 +358,7 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
raise Exception(f'TICKET COMMS ERROR: {ticket.decode()} {msg} {e!s}')
|
raise Exception(f'TICKET COMMS ERROR: {ticket.decode()} {msg} {e!s}')
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def getTicketFromUDS(
|
async def get_ticket_from_uds(
|
||||||
cfg: config.ConfigurationType, ticket: bytes, address: typing.Tuple[str, int]
|
cfg: config.ConfigurationType, ticket: bytes, address: typing.Tuple[str, int]
|
||||||
) -> typing.MutableMapping[str, typing.Any]:
|
) -> typing.MutableMapping[str, typing.Any]:
|
||||||
# Sanity checks
|
# Sanity checks
|
||||||
@ -374,13 +374,13 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
continue # Correctus
|
continue # Correctus
|
||||||
raise ValueError(f'TICKET INVALID (char {i} at pos {n})')
|
raise ValueError(f'TICKET INVALID (char {i} at pos {n})')
|
||||||
|
|
||||||
return await TunnelProtocol._readFromUDS(cfg, ticket, address[0])
|
return await TunnelProtocol._read_from_uds(cfg, ticket, address[0])
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def notifyEndToUds(
|
async def notify_end_to_uds(
|
||||||
cfg: config.ConfigurationType, ticket: bytes, counter: stats.Stats
|
cfg: config.ConfigurationType, ticket: bytes, counter: stats.Stats
|
||||||
) -> None:
|
) -> None:
|
||||||
await TunnelProtocol._readFromUDS(
|
await TunnelProtocol._read_from_uds(
|
||||||
cfg,
|
cfg,
|
||||||
ticket,
|
ticket,
|
||||||
'stop',
|
'stop',
|
||||||
|
@ -30,7 +30,8 @@ Author: Adolfo Gómez, dkmaster at dkmon dot com
|
|||||||
'''
|
'''
|
||||||
import typing
|
import typing
|
||||||
import asyncio
|
import asyncio
|
||||||
import io
|
import random
|
||||||
|
import string
|
||||||
import logging
|
import logging
|
||||||
from unittest import IsolatedAsyncioTestCase, mock
|
from unittest import IsolatedAsyncioTestCase, mock
|
||||||
|
|
||||||
@ -38,8 +39,12 @@ from uds_tunnel import consts
|
|||||||
|
|
||||||
from .utils import tuntools, tools
|
from .utils import tuntools, tools
|
||||||
|
|
||||||
|
if typing.TYPE_CHECKING:
|
||||||
|
from uds_tunnel import config
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class TestUDSTunnelApp(IsolatedAsyncioTestCase):
|
class TestUDSTunnelApp(IsolatedAsyncioTestCase):
|
||||||
async def test_run_app_help(self) -> None:
|
async def test_run_app_help(self) -> None:
|
||||||
# Executes the app with --help
|
# Executes the app with --help
|
||||||
@ -49,10 +54,49 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
|
|||||||
self.assertEqual(process.returncode, 0, f'{stdout!r} {stderr!r}')
|
self.assertEqual(process.returncode, 0, f'{stdout!r} {stderr!r}')
|
||||||
self.assertEqual(stderr, b'')
|
self.assertEqual(stderr, b'')
|
||||||
self.assertIn(b'usage: udstunnel', stdout)
|
self.assertIn(b'usage: udstunnel', stdout)
|
||||||
|
|
||||||
|
async def client_task(
|
||||||
|
self, cfg: 'config.ConfigurationType', host: str, port: int
|
||||||
|
) -> None:
|
||||||
|
received: bytes = b''
|
||||||
|
callback_invoked: asyncio.Event = asyncio.Event()
|
||||||
|
|
||||||
|
def callback(data: bytes) -> None:
|
||||||
|
nonlocal received
|
||||||
|
received += data
|
||||||
|
# if data contains EOS marcker ('STREAM_END'), we are done
|
||||||
|
if b'STREAM_END' in data:
|
||||||
|
callback_invoked.set()
|
||||||
|
|
||||||
|
async with tools.AsyncTCPServer(
|
||||||
|
host=host, port=5445, callback=callback
|
||||||
|
) as server:
|
||||||
|
# Create a random ticket with valid format
|
||||||
|
ticket = ''.join(
|
||||||
|
random.choice(string.ascii_letters + string.digits)
|
||||||
|
for _ in range(consts.TICKET_LENGTH)
|
||||||
|
).encode()
|
||||||
|
# Open and send handshake
|
||||||
|
async with tuntools.open_tunnel_client(cfg, use_tunnel_handshake=True) as (
|
||||||
|
creader,
|
||||||
|
cwriter,
|
||||||
|
):
|
||||||
|
# Now open command with ticket
|
||||||
|
cwriter.write(consts.COMMAND_OPEN)
|
||||||
|
# fake ticket, consts.TICKET_LENGTH bytes long, letters and numbers. Use a random ticket,
|
||||||
|
cwriter.write(ticket)
|
||||||
|
|
||||||
|
await cwriter.drain()
|
||||||
|
# Read response, should be ok
|
||||||
|
data = await creader.read(1024)
|
||||||
|
self.assertEqual(
|
||||||
|
data,
|
||||||
|
consts.RESPONSE_OK,
|
||||||
|
f'Server host: {host}:{port} - Ticket: {ticket!r} - Response: {data!r}',
|
||||||
|
)
|
||||||
|
|
||||||
async def test_run_app_serve(self) -> None:
|
async def test_run_app_serve(self) -> None:
|
||||||
for host in ('127.0.0.1', '::1'):
|
for host in ('127.0.0.1', '::1'):
|
||||||
async with tuntools.tunnel_app_runner(host, 7777) as process:
|
async with tuntools.tunnel_app_runner(host, 7777) as process:
|
||||||
print(process)
|
# Create a "bunch" of servers and clients
|
||||||
|
pass
|
||||||
|
@ -66,7 +66,7 @@ class TestTunnel(IsolatedAsyncioTestCase):
|
|||||||
# Set timeout to 1 seconds
|
# Set timeout to 1 seconds
|
||||||
bad_cmd = bytes(random.randint(0, 255) for _ in range(i)) # Some garbage
|
bad_cmd = bytes(random.randint(0, 255) for _ in range(i)) # Some garbage
|
||||||
logger.info(f'Testing invalid command with {bad_cmd!r}')
|
logger.info(f'Testing invalid command with {bad_cmd!r}')
|
||||||
async with tuntools.create_test_tunnel(callback=lambda x: None) as cfg:
|
async with tuntools.create_test_tunnel(callback=lambda x: None, port=7770, remote_port=54555) as cfg:
|
||||||
logger_mock = mock.MagicMock()
|
logger_mock = mock.MagicMock()
|
||||||
with mock.patch('uds_tunnel.tunnel.logger', logger_mock):
|
with mock.patch('uds_tunnel.tunnel.logger', logger_mock):
|
||||||
# Open connection to tunnel
|
# Open connection to tunnel
|
||||||
|
@ -65,7 +65,7 @@ class TestTunnelHelpers(IsolatedAsyncioTestCase):
|
|||||||
# Test some invalid tickets
|
# Test some invalid tickets
|
||||||
# Valid ticket are consts.TICKET_LENGTH bytes long, and must be A-Z, a-z, 0-9
|
# Valid ticket are consts.TICKET_LENGTH bytes long, and must be A-Z, a-z, 0-9
|
||||||
with mock.patch(
|
with mock.patch(
|
||||||
'uds_tunnel.tunnel.TunnelProtocol._readFromUDS',
|
'uds_tunnel.tunnel.TunnelProtocol._read_from_uds',
|
||||||
new_callable=tools.AsyncMock,
|
new_callable=tools.AsyncMock,
|
||||||
) as m:
|
) as m:
|
||||||
m.side_effect = uds_response
|
m.side_effect = uds_response
|
||||||
@ -77,7 +77,7 @@ class TestTunnelHelpers(IsolatedAsyncioTestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
await tunnel.TunnelProtocol.getTicketFromUDS(
|
await tunnel.TunnelProtocol.get_ticket_from_uds(
|
||||||
cfg, ticket.encode(), conf.CALLER_HOST
|
cfg, ticket.encode(), conf.CALLER_HOST
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -85,7 +85,7 @@ class TestTunnelHelpers(IsolatedAsyncioTestCase):
|
|||||||
for i in range(0, 100):
|
for i in range(0, 100):
|
||||||
# Now some requests with valid tickets
|
# Now some requests with valid tickets
|
||||||
# Ensure no exception is raised
|
# Ensure no exception is raised
|
||||||
ret_value = await tunnel.TunnelProtocol.getTicketFromUDS(
|
ret_value = await tunnel.TunnelProtocol.get_ticket_from_uds(
|
||||||
cfg, ticket.encode(), conf.CALLER_HOST
|
cfg, ticket.encode(), conf.CALLER_HOST
|
||||||
)
|
)
|
||||||
# Ensure data returned is correct {host, port, notify} from mock
|
# Ensure data returned is correct {host, port, notify} from mock
|
||||||
@ -108,7 +108,7 @@ class TestTunnelHelpers(IsolatedAsyncioTestCase):
|
|||||||
async def test_notify_end_to_uds_broker(self) -> None:
|
async def test_notify_end_to_uds_broker(self) -> None:
|
||||||
_, cfg = fixtures.get_config()
|
_, cfg = fixtures.get_config()
|
||||||
with mock.patch(
|
with mock.patch(
|
||||||
'uds_tunnel.tunnel.TunnelProtocol._readFromUDS',
|
'uds_tunnel.tunnel.TunnelProtocol._read_from_uds',
|
||||||
new_callable=tools.AsyncMock,
|
new_callable=tools.AsyncMock,
|
||||||
) as m:
|
) as m:
|
||||||
m.side_effect = uds_response
|
m.side_effect = uds_response
|
||||||
@ -118,7 +118,7 @@ class TestTunnelHelpers(IsolatedAsyncioTestCase):
|
|||||||
|
|
||||||
ticket = conf.NOTIFY_TICKET.encode()
|
ticket = conf.NOTIFY_TICKET.encode()
|
||||||
for i in range(0, 100):
|
for i in range(0, 100):
|
||||||
await tunnel.TunnelProtocol.notifyEndToUds(cfg, ticket, counter)
|
await tunnel.TunnelProtocol.notify_end_to_uds(cfg, ticket, counter)
|
||||||
|
|
||||||
self.assertEqual(m.call_args[0][0], cfg)
|
self.assertEqual(m.call_args[0][0], cfg)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
@ -159,9 +159,9 @@ class TestTunnelHelpers(IsolatedAsyncioTestCase):
|
|||||||
await tools.get(fake_uds_server),
|
await tools.get(fake_uds_server),
|
||||||
'{"result":"ok"}',
|
'{"result":"ok"}',
|
||||||
)
|
)
|
||||||
# Now, tests _readFromUDS
|
# Now, tests _read_from_uds
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
ret = await tunnel.TunnelProtocol._readFromUDS(
|
ret = await tunnel.TunnelProtocol._read_from_uds(
|
||||||
cfg, conf.NOTIFY_TICKET.encode(), 'test', {'param': 'value'}
|
cfg, conf.NOTIFY_TICKET.encode(), 'test', {'param': 'value'}
|
||||||
)
|
)
|
||||||
self.assertEqual(ret, {'result': 'ok'})
|
self.assertEqual(ret, {'result': 'ok'})
|
||||||
|
@ -55,9 +55,9 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
|
|||||||
# Remote is not really important in this tests, will fail before using it
|
# Remote is not really important in this tests, will fail before using it
|
||||||
async with tuntools.create_tunnel_proc(
|
async with tuntools.create_tunnel_proc(
|
||||||
host,
|
host,
|
||||||
7777,
|
7890, # A port not used by any other test
|
||||||
'127.0.0.1',
|
'127.0.0.1',
|
||||||
12345,
|
13579, # A port not used by any other test
|
||||||
) as cfg:
|
) as cfg:
|
||||||
for i in range(0, 8192, 128):
|
for i in range(0, 8192, 128):
|
||||||
# Set timeout to 1 seconds
|
# Set timeout to 1 seconds
|
||||||
@ -85,9 +85,9 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
|
|||||||
# Remote is not really important in this tests, will return ok before using it (this is a TEST command, not OPEN)
|
# Remote is not really important in this tests, will return ok before using it (this is a TEST command, not OPEN)
|
||||||
async with tuntools.create_tunnel_proc(
|
async with tuntools.create_tunnel_proc(
|
||||||
host,
|
host,
|
||||||
7777,
|
7891,
|
||||||
'127.0.0.1',
|
'127.0.0.1',
|
||||||
12345,
|
13581,
|
||||||
) as cfg:
|
) as cfg:
|
||||||
for i in range(10): # Several times
|
for i in range(10): # Several times
|
||||||
# On full, we need the handshake to be done, before connecting
|
# On full, we need the handshake to be done, before connecting
|
||||||
@ -117,10 +117,7 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
|
|||||||
0, consts.TICKET_LENGTH - 1, 4
|
0, consts.TICKET_LENGTH - 1, 4
|
||||||
): # All will fail. Any longer will be processed, and mock will return correct don't matter the ticket
|
): # All will fail. Any longer will be processed, and mock will return correct don't matter the ticket
|
||||||
# Ticket must contain only letters and numbers
|
# Ticket must contain only letters and numbers
|
||||||
ticket = ''.join(
|
ticket = tuntools.get_correct_ticket(i)
|
||||||
random.choice(string.ascii_letters + string.digits)
|
|
||||||
for _ in range(i)
|
|
||||||
).encode()
|
|
||||||
# On full, we need the handshake to be done, before connecting
|
# On full, we need the handshake to be done, before connecting
|
||||||
# Our "test" server will simple "eat" the handshake, but we need to do it
|
# Our "test" server will simple "eat" the handshake, but we need to do it
|
||||||
async with tuntools.open_tunnel_client(
|
async with tuntools.open_tunnel_client(
|
||||||
@ -161,10 +158,7 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
|
|||||||
) as cfg:
|
) as cfg:
|
||||||
for i in range(1):
|
for i in range(1):
|
||||||
# Create a random ticket with valid format
|
# Create a random ticket with valid format
|
||||||
ticket = ''.join(
|
ticket = tuntools.get_correct_ticket()
|
||||||
random.choice(string.ascii_letters + string.digits)
|
|
||||||
for _ in range(consts.TICKET_LENGTH)
|
|
||||||
).encode()
|
|
||||||
# On full, we need the handshake to be done, before connecting
|
# On full, we need the handshake to be done, before connecting
|
||||||
# Our "test" server will simple "eat" the handshake, but we need to do it
|
# Our "test" server will simple "eat" the handshake, but we need to do it
|
||||||
async with tuntools.open_tunnel_client(
|
async with tuntools.open_tunnel_client(
|
||||||
@ -175,21 +169,53 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
|
|||||||
cwriter.write(ticket)
|
cwriter.write(ticket)
|
||||||
|
|
||||||
await cwriter.drain()
|
await cwriter.drain()
|
||||||
# Read response
|
# Read response, should be ok
|
||||||
data = await creader.read(1024)
|
data = await creader.read(1024)
|
||||||
self.assertEqual(data, consts.RESPONSE_OK, f'Tunnel host: {tunnel_host}, server host: {host}')
|
self.assertEqual(
|
||||||
|
data,
|
||||||
|
consts.RESPONSE_OK,
|
||||||
|
f'Tunnel host: {tunnel_host}, server host: {host}',
|
||||||
|
)
|
||||||
|
|
||||||
# Data sent will be received by server
|
# Data sent will be received by server
|
||||||
# One single write will ensure all data is on same packet
|
# One single write will ensure all data is on same packet
|
||||||
test_str = b'Some Random Data' + bytes(random.randint(0, 255) for _ in range(8192)) + b'STREAM_END'
|
test_str = (
|
||||||
|
b'Some Random Data'
|
||||||
|
+ bytes(random.randint(0, 255) for _ in range(8192))
|
||||||
|
+ b'STREAM_END'
|
||||||
|
)
|
||||||
# Clean received data
|
# Clean received data
|
||||||
received = b''
|
received = b''
|
||||||
# And reset event
|
# And reset event
|
||||||
callback_invoked.clear()
|
callback_invoked.clear()
|
||||||
|
|
||||||
cwriter.write(test_str)
|
cwriter.write(test_str)
|
||||||
await cwriter.drain()
|
await cwriter.drain()
|
||||||
|
|
||||||
# Wait for callback to be invoked
|
# Wait for callback to be invoked
|
||||||
await callback_invoked.wait()
|
await callback_invoked.wait()
|
||||||
self.assertEqual(received, test_str)
|
self.assertEqual(received, test_str)
|
||||||
|
|
||||||
|
async def test_tunnel_no_remote(self) -> None:
|
||||||
|
for host in ('127.0.0.1', '::1'):
|
||||||
|
for tunnel_host in ('127.0.0.1', '::1'):
|
||||||
|
async with tuntools.create_tunnel_proc(
|
||||||
|
tunnel_host,
|
||||||
|
7888,
|
||||||
|
host,
|
||||||
|
17222, # Any non used port will do the trick
|
||||||
|
) as cfg:
|
||||||
|
ticket = tuntools.get_correct_ticket()
|
||||||
|
# On full, we need the handshake to be done, before connecting
|
||||||
|
# Our "test" server will simple "eat" the handshake, but we need to do it
|
||||||
|
async with tuntools.open_tunnel_client(
|
||||||
|
cfg, use_tunnel_handshake=True
|
||||||
|
) as (creader, cwriter):
|
||||||
|
cwriter.write(consts.COMMAND_OPEN)
|
||||||
|
# fake ticket, consts.TICKET_LENGTH bytes long, letters and numbers. Use a random ticket,
|
||||||
|
cwriter.write(ticket)
|
||||||
|
|
||||||
|
await cwriter.drain()
|
||||||
|
# Read response
|
||||||
|
data = await creader.read(1024)
|
||||||
|
self.assertEqual(data, b'', f'Tunnel host: {tunnel_host}, server host: {host}')
|
||||||
|
@ -38,7 +38,6 @@ from unittest import mock
|
|||||||
|
|
||||||
from . import certs
|
from . import certs
|
||||||
|
|
||||||
|
|
||||||
class AsyncMock(mock.MagicMock):
|
class AsyncMock(mock.MagicMock):
|
||||||
async def __call__(self, *args, **kwargs):
|
async def __call__(self, *args, **kwargs):
|
||||||
return super().__call__(*args, **kwargs)
|
return super().__call__(*args, **kwargs)
|
||||||
@ -111,7 +110,8 @@ class AsyncTCPServer:
|
|||||||
port: int
|
port: int
|
||||||
_server: typing.Optional[asyncio.AbstractServer]
|
_server: typing.Optional[asyncio.AbstractServer]
|
||||||
_response: typing.Optional[bytes]
|
_response: typing.Optional[bytes]
|
||||||
_callback: typing.Optional[typing.Callable[[bytes], None]]
|
_callback: typing.Optional[typing.Callable[[bytes], typing.Optional[bytes]]]
|
||||||
|
_processor: typing.Optional[typing.Callable[[asyncio.StreamReader, asyncio.StreamWriter], None]]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -119,26 +119,30 @@ class AsyncTCPServer:
|
|||||||
*,
|
*,
|
||||||
response: typing.Optional[bytes] = None,
|
response: typing.Optional[bytes] = None,
|
||||||
host: str = '127.0.0.1', # ip
|
host: str = '127.0.0.1', # ip
|
||||||
callback: typing.Optional[typing.Callable[[bytes], None]] = None,
|
callback: typing.Optional[typing.Callable[[bytes], typing.Optional[bytes]]] = None,
|
||||||
|
processor: typing.Optional[typing.Callable[[asyncio.StreamReader, asyncio.StreamWriter], None]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.host = host
|
self.host = host
|
||||||
self.port = port
|
self.port = port
|
||||||
self._server = None
|
self._server = None
|
||||||
self._response = response
|
self._response = response
|
||||||
self._callback = callback
|
self._callback = callback
|
||||||
|
self._processor = processor
|
||||||
|
|
||||||
self.data = b''
|
self.data = b''
|
||||||
|
|
||||||
async def _handle(self, reader, writer) -> None:
|
async def _handle(self, reader, writer) -> None:
|
||||||
|
if self._processor is not None:
|
||||||
|
self._processor(reader, writer)
|
||||||
|
return
|
||||||
while True:
|
while True:
|
||||||
data = await reader.read(2048)
|
data = await reader.read(2048)
|
||||||
if not data:
|
if not data:
|
||||||
break
|
break
|
||||||
|
|
||||||
if self._callback:
|
resp = self._callback(data) if self._callback else self._response
|
||||||
self._callback(data)
|
|
||||||
|
|
||||||
if self._response is not None:
|
if resp is not None:
|
||||||
data = self._response
|
data = self._response
|
||||||
writer.write(data)
|
writer.write(data)
|
||||||
await writer.drain()
|
await writer.drain()
|
||||||
|
@ -30,22 +30,23 @@ Author: Adolfo Gómez, dkmaster at dkmon dot com
|
|||||||
'''
|
'''
|
||||||
import asyncio
|
import asyncio
|
||||||
import contextlib
|
import contextlib
|
||||||
import os
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import multiprocessing
|
||||||
|
import os
|
||||||
|
import random
|
||||||
import socket
|
import socket
|
||||||
import ssl
|
import ssl
|
||||||
import os
|
import string
|
||||||
import typing
|
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import typing
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
import multiprocessing
|
|
||||||
|
|
||||||
import udstunnel
|
import udstunnel
|
||||||
from uds_tunnel import consts, tunnel, stats, config
|
from uds_tunnel import config, consts, stats, tunnel
|
||||||
|
|
||||||
from . import certs, conf, fixtures, tools
|
from . import certs, conf, fixtures, tools
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
if typing.TYPE_CHECKING:
|
||||||
@ -100,7 +101,7 @@ async def create_tunnel_proc(
|
|||||||
remote_host: str,
|
remote_host: str,
|
||||||
remote_port: int,
|
remote_port: int,
|
||||||
*,
|
*,
|
||||||
response: typing.Optional[typing.Mapping[str, typing.Any]] = None
|
response: typing.Optional[typing.Mapping[str, typing.Any]] = None,
|
||||||
) -> typing.AsyncGenerator['config.ConfigurationType', None]:
|
) -> typing.AsyncGenerator['config.ConfigurationType', None]:
|
||||||
with create_config_file(listen_host, listen_port) as cfgfile:
|
with create_config_file(listen_host, listen_port) as cfgfile:
|
||||||
args = mock.MagicMock()
|
args = mock.MagicMock()
|
||||||
@ -116,7 +117,7 @@ async def create_tunnel_proc(
|
|||||||
response = conf.UDS_GET_TICKET_RESPONSE(remote_host, remote_port)
|
response = conf.UDS_GET_TICKET_RESPONSE(remote_host, remote_port)
|
||||||
|
|
||||||
with mock.patch(
|
with mock.patch(
|
||||||
'uds_tunnel.tunnel.TunnelProtocol._readFromUDS',
|
'uds_tunnel.tunnel.TunnelProtocol._read_from_uds',
|
||||||
new_callable=tools.AsyncMock,
|
new_callable=tools.AsyncMock,
|
||||||
) as m:
|
) as m:
|
||||||
m.return_value = response
|
m.return_value = response
|
||||||
@ -214,21 +215,26 @@ async def create_tunnel_server(
|
|||||||
|
|
||||||
@contextlib.asynccontextmanager
|
@contextlib.asynccontextmanager
|
||||||
async def create_test_tunnel(
|
async def create_test_tunnel(
|
||||||
*, callback: typing.Callable[[bytes], None]
|
*,
|
||||||
|
callback: typing.Callable[[bytes], None],
|
||||||
|
port: typing.Optional[int] = None,
|
||||||
|
remote_port: typing.Optional[int] = None,
|
||||||
) -> typing.AsyncGenerator['config.ConfigurationType', None]:
|
) -> typing.AsyncGenerator['config.ConfigurationType', None]:
|
||||||
# Generate a listening server for testing tunnel
|
# Generate a listening server for testing tunnel
|
||||||
# Prepare the end of the tunnel
|
# Prepare the end of the tunnel
|
||||||
async with tools.AsyncTCPServer(port=54876, callback=callback) as server:
|
async with tools.AsyncTCPServer(
|
||||||
|
port=remote_port or 54876, callback=callback
|
||||||
|
) as server:
|
||||||
# Create a tunnel to localhost 13579
|
# Create a tunnel to localhost 13579
|
||||||
# SSl cert for tunnel server
|
# SSl cert for tunnel server
|
||||||
with certs.ssl_context(server.host) as (ssl_ctx, _):
|
with certs.ssl_context(server.host) as (ssl_ctx, _):
|
||||||
_, cfg = fixtures.get_config(
|
_, cfg = fixtures.get_config(
|
||||||
address=server.host,
|
address=server.host,
|
||||||
port=7777,
|
port=port or 7777,
|
||||||
ipv6=':' in server.host,
|
ipv6=':' in server.host,
|
||||||
)
|
)
|
||||||
with mock.patch(
|
with mock.patch(
|
||||||
'uds_tunnel.tunnel.TunnelProtocol._readFromUDS',
|
'uds_tunnel.tunnel.TunnelProtocol._read_from_uds',
|
||||||
new_callable=tools.AsyncMock,
|
new_callable=tools.AsyncMock,
|
||||||
) as m:
|
) as m:
|
||||||
m.return_value = conf.UDS_GET_TICKET_RESPONSE(server.host, server.port)
|
m.return_value = conf.UDS_GET_TICKET_RESPONSE(server.host, server.port)
|
||||||
@ -241,6 +247,26 @@ async def create_test_tunnel(
|
|||||||
await tunnel_server.wait_closed()
|
await tunnel_server.wait_closed()
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.asynccontextmanager
|
||||||
|
async def create_fake_broker_server(
|
||||||
|
response: typing.Mapping[str, typing.Any], port: int = 44443
|
||||||
|
) -> typing.AsyncGenerator[None, None]:
|
||||||
|
# crate a fake broker server
|
||||||
|
# Ignores request, and sends response
|
||||||
|
|
||||||
|
resp: bytes = b'HTTP/1.1 200 OK\r\n\r\n' + json.dumps(response).encode()
|
||||||
|
|
||||||
|
def callback(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
|
||||||
|
writer.write(resp)
|
||||||
|
writer.close()
|
||||||
|
|
||||||
|
async with tools.AsyncTCPServer(port=port, response=resp) as server:
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
pass # nothing to do
|
||||||
|
|
||||||
|
|
||||||
@contextlib.asynccontextmanager
|
@contextlib.asynccontextmanager
|
||||||
async def open_tunnel_client(
|
async def open_tunnel_client(
|
||||||
cfg: 'config.ConfigurationType',
|
cfg: 'config.ConfigurationType',
|
||||||
@ -284,7 +310,7 @@ async def tunnel_app_runner(
|
|||||||
host: typing.Optional[str] = None,
|
host: typing.Optional[str] = None,
|
||||||
port: typing.Optional[int] = None,
|
port: typing.Optional[int] = None,
|
||||||
*,
|
*,
|
||||||
args: typing.Optional[typing.List[str]] = None
|
args: typing.Optional[typing.List[str]] = None,
|
||||||
) -> typing.AsyncGenerator['Process', None]:
|
) -> typing.AsyncGenerator['Process', None]:
|
||||||
# Ensure we are on src directory
|
# Ensure we are on src directory
|
||||||
if os.path.basename(os.getcwd()) != 'src':
|
if os.path.basename(os.getcwd()) != 'src':
|
||||||
@ -312,3 +338,10 @@ async def tunnel_app_runner(
|
|||||||
if process.returncode is None:
|
if process.returncode is None:
|
||||||
process.terminate()
|
process.terminate()
|
||||||
await process.wait()
|
await process.wait()
|
||||||
|
|
||||||
|
|
||||||
|
def get_correct_ticket(length: int = consts.TICKET_LENGTH) -> bytes:
|
||||||
|
return ''.join(
|
||||||
|
random.choice(string.ascii_letters + string.digits)
|
||||||
|
for _ in range(length)
|
||||||
|
).encode()
|
||||||
|
Loading…
Reference in New Issue
Block a user