mirror of
https://github.com/dkmstr/openuds.git
synced 2024-12-23 17:34:17 +03:00
Added test to udstunnel man async proc
This commit is contained in:
parent
9f4159f18d
commit
370799912f
@ -56,6 +56,7 @@ class ConfigurationType(typing.NamedTuple):
|
||||
|
||||
ssl_certificate: str
|
||||
ssl_certificate_key: str
|
||||
ssl_password: str
|
||||
ssl_ciphers: str
|
||||
ssl_dhparam: str
|
||||
|
||||
@ -122,6 +123,7 @@ def read(
|
||||
workers=int(uds.get('workers', '0')) or multiprocessing.cpu_count(),
|
||||
ssl_certificate=uds['ssl_certificate'],
|
||||
ssl_certificate_key=uds.get('ssl_certificate_key', ''),
|
||||
ssl_password=uds.get('ssl_password', ''),
|
||||
ssl_ciphers=uds.get('ssl_ciphers'),
|
||||
ssl_dhparam=uds.get('ssl_dhparam'),
|
||||
uds_server=uds_server,
|
||||
|
@ -39,6 +39,8 @@ import ssl
|
||||
import socket
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
# event for stop notification
|
||||
import threading
|
||||
import typing
|
||||
|
||||
try:
|
||||
@ -62,12 +64,12 @@ if typing.TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
do_stop = False
|
||||
running: threading.Event = threading.Event()
|
||||
|
||||
|
||||
def stop_signal(signum: int, frame: typing.Any) -> None:
|
||||
global do_stop
|
||||
do_stop = True
|
||||
global running
|
||||
running.clear()
|
||||
logger.debug('SIGNAL %s, frame: %s', signum, frame)
|
||||
|
||||
|
||||
@ -119,8 +121,13 @@ async def tunnel_proc_async(
|
||||
] = pipe.recv()
|
||||
if msg:
|
||||
return msg
|
||||
except EOFError:
|
||||
logger.debug('Parent process closed connection')
|
||||
pipe.close()
|
||||
return None, None
|
||||
except Exception:
|
||||
logger.exception('Receiving data from parent process')
|
||||
pipe.close()
|
||||
return None, None
|
||||
|
||||
async def run_server() -> None:
|
||||
@ -129,11 +136,15 @@ async def tunnel_proc_async(
|
||||
|
||||
# Generate SSL context
|
||||
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||
|
||||
args: typing.Dict[str, typing.Any] = {
|
||||
'certfile': cfg.ssl_certificate,
|
||||
}
|
||||
if cfg.ssl_certificate_key:
|
||||
context.load_cert_chain(cfg.ssl_certificate, cfg.ssl_certificate_key)
|
||||
else:
|
||||
context.load_cert_chain(cfg.ssl_certificate)
|
||||
args['keyfile'] = cfg.ssl_certificate_key
|
||||
if cfg.ssl_password:
|
||||
args['password'] = cfg.ssl_password
|
||||
|
||||
context.load_cert_chain(**args)
|
||||
|
||||
if cfg.ssl_ciphers:
|
||||
context.set_ciphers(cfg.ssl_ciphers)
|
||||
@ -141,29 +152,37 @@ async def tunnel_proc_async(
|
||||
if cfg.ssl_dhparam:
|
||||
context.load_dh_params(cfg.ssl_dhparam)
|
||||
|
||||
while True:
|
||||
address: typing.Optional[typing.Tuple[str, int]] = ('', 0)
|
||||
try:
|
||||
(sock, address) = await loop.run_in_executor(None, get_socket)
|
||||
if not sock:
|
||||
break # No more sockets, exit
|
||||
logger.debug(f'CONNECTION from {address!r} (pid: {os.getpid()})')
|
||||
tasks.append(asyncio.create_task(tunneler(sock, context)))
|
||||
except Exception:
|
||||
logger.error('NEGOTIATION ERROR from %s', address[0] if address else 'unknown')
|
||||
try:
|
||||
while True:
|
||||
address: typing.Optional[typing.Tuple[str, int]] = ('', 0)
|
||||
try:
|
||||
(sock, address) = await loop.run_in_executor(None, get_socket)
|
||||
if not sock:
|
||||
break # No more sockets, exit
|
||||
logger.debug(f'CONNECTION from {address!r} (pid: {os.getpid()})')
|
||||
tasks.append(asyncio.create_task(tunneler(sock, context)))
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
logger.error('NEGOTIATION ERROR from %s', address[0] if address else 'unknown')
|
||||
except asyncio.CancelledError:
|
||||
pass # Stop
|
||||
|
||||
# create task for server
|
||||
tasks.append(asyncio.create_task(run_server()))
|
||||
|
||||
while tasks and not do_stop:
|
||||
to_wait = tasks[:] # Get a copy of the list, and clean the original
|
||||
# Wait for tasks to finish
|
||||
done, _ = await asyncio.wait(to_wait, return_when=asyncio.FIRST_COMPLETED)
|
||||
# Remove finished tasks
|
||||
for task in done:
|
||||
tasks.remove(task)
|
||||
if task.exception():
|
||||
logger.exception('TUNNEL ERROR')
|
||||
try:
|
||||
while tasks and running.is_set():
|
||||
to_wait = tasks[:] # Get a copy of the list, and clean the original
|
||||
# Wait for tasks to finish
|
||||
done, _ = await asyncio.wait(to_wait, return_when=asyncio.FIRST_COMPLETED, timeout=2)
|
||||
# Remove finished tasks
|
||||
for task in done:
|
||||
tasks.remove(task)
|
||||
if task.exception():
|
||||
logger.exception('TUNNEL ERROR')
|
||||
except asyncio.CancelledError:
|
||||
running.clear() # ensure we stop
|
||||
|
||||
# If any task is still running, cancel it
|
||||
for task in tasks:
|
||||
@ -244,16 +263,18 @@ def tunnel_main(args: 'argparse.Namespace') -> None:
|
||||
signal.signal(signal.SIGINT, stop_signal)
|
||||
signal.signal(signal.SIGTERM, stop_signal)
|
||||
except Exception as e:
|
||||
# Signal not available on threads, and testing uses threads
|
||||
# Signal not available on threads, and we use threads on tests,
|
||||
# so we will ignore this because on tests signals are not important
|
||||
logger.warning('Signal not available: %s', e)
|
||||
|
||||
stats_collector = stats.GlobalStats()
|
||||
|
||||
prcs = processes.Processes(tunnel_proc_async, cfg, stats_collector.ns)
|
||||
|
||||
running.set()
|
||||
with ThreadPoolExecutor(max_workers=256) as executor:
|
||||
try:
|
||||
while not do_stop:
|
||||
while running.is_set():
|
||||
try:
|
||||
client, addr = sock.accept()
|
||||
logger.info('CONNECTION from %s', addr)
|
||||
|
@ -32,7 +32,7 @@ import hashlib
|
||||
|
||||
from unittest import TestCase
|
||||
|
||||
from . import fixtures
|
||||
from .utils import fixtures
|
||||
|
||||
|
||||
class TestConfigFile(TestCase):
|
||||
|
@ -30,10 +30,7 @@ Author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
'''
|
||||
import typing
|
||||
import random
|
||||
import asyncio
|
||||
import contextlib
|
||||
import socket
|
||||
import ssl
|
||||
import logging
|
||||
import multiprocessing
|
||||
from unittest import IsolatedAsyncioTestCase, mock
|
||||
@ -41,11 +38,7 @@ from unittest import IsolatedAsyncioTestCase, mock
|
||||
from udstunnel import process_connection
|
||||
from uds_tunnel import tunnel, consts
|
||||
|
||||
from . import fixtures
|
||||
from .utils import tools, certs, conf
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from uds_tunnel import config
|
||||
from .utils import tuntools
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -63,16 +56,16 @@ class TestTunnel(IsolatedAsyncioTestCase):
|
||||
|
||||
# Send invalid commands and see what happens
|
||||
# Commands are 4 bytes length, try with less and more invalid commands
|
||||
consts.TIMEOUT_COMMAND = 0.1 # type: ignore # timeout is a final variable, but we need to change it for testing speed
|
||||
for i in range(0, 100, 10):
|
||||
# Set timeout to 1 seconds
|
||||
bad_cmd = bytes(random.randint(0, 255) for _ in range(i)) # Some garbage
|
||||
consts.TIMEOUT_COMMAND = 0.1 # type: ignore # timeout is a final variable, but we need to change it for testing speed
|
||||
logger.info(f'Testing invalid command with {bad_cmd!r}')
|
||||
async with TestTunnel.create_test_tunnel(callback=lambda x: None) as cfg:
|
||||
async with tuntools.create_test_tunnel(callback=lambda x: None) as cfg:
|
||||
logger_mock = mock.MagicMock()
|
||||
with mock.patch('uds_tunnel.tunnel.logger', logger_mock):
|
||||
# Open connection to tunnel
|
||||
async with TestTunnel.open_tunnel(cfg) as (reader, writer):
|
||||
async with tuntools.open_tunnel_client(cfg) as (reader, writer):
|
||||
# Send data
|
||||
writer.write(bad_cmd)
|
||||
await writer.drain()
|
||||
@ -151,79 +144,3 @@ class TestTunnel(IsolatedAsyncioTestCase):
|
||||
# recv()[0] will be a copy of the socket, we don't care about it
|
||||
self.assertEqual(other_conn.recv()[1], ('host', 'port'))
|
||||
|
||||
@staticmethod
|
||||
async def create_tunnel_server(
|
||||
cfg: 'config.ConfigurationType', context: 'ssl.SSLContext'
|
||||
) -> 'asyncio.Server':
|
||||
# Create fake proxy
|
||||
proxy = mock.MagicMock()
|
||||
proxy.cfg = cfg
|
||||
proxy.ns = mock.MagicMock()
|
||||
proxy.ns.current = 0
|
||||
proxy.ns.total = 0
|
||||
proxy.ns.sent = 0
|
||||
proxy.ns.recv = 0
|
||||
proxy.counter = 0
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
# Create an asyncio listen socket on cfg.listen_host:cfg.listen_port
|
||||
return await loop.create_server(
|
||||
lambda: tunnel.TunnelProtocol(proxy),
|
||||
cfg.listen_address,
|
||||
cfg.listen_port,
|
||||
ssl=context,
|
||||
family=socket.AF_INET6
|
||||
if cfg.listen_ipv6 or ':' in cfg.listen_address
|
||||
else socket.AF_INET,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@contextlib.asynccontextmanager
|
||||
async def create_test_tunnel(
|
||||
*, callback: typing.Callable[[bytes], None]
|
||||
) -> typing.AsyncGenerator['config.ConfigurationType', None]:
|
||||
# Generate a listening server for testing tunnel
|
||||
# Prepare the end of the tunnel
|
||||
async with tools.AsyncTCPServer(port=54876, callback=callback) as server:
|
||||
# Create a tunnel to localhost 13579
|
||||
# SSl cert for tunnel server
|
||||
with certs.ssl_context(server.host) as (ssl_ctx, _):
|
||||
_, cfg = fixtures.get_config(
|
||||
address=server.host,
|
||||
port=7777,
|
||||
)
|
||||
with mock.patch(
|
||||
'uds_tunnel.tunnel.TunnelProtocol._readFromUDS',
|
||||
new_callable=tools.AsyncMock,
|
||||
) as m:
|
||||
m.return_value = conf.UDS_GET_TICKET_RESPONSE(
|
||||
server.host, server.port
|
||||
)
|
||||
|
||||
tunnel_server = await TestTunnel.create_tunnel_server(cfg, ssl_ctx)
|
||||
yield cfg
|
||||
tunnel_server.close()
|
||||
await tunnel_server.wait_closed()
|
||||
|
||||
@staticmethod
|
||||
@contextlib.asynccontextmanager
|
||||
async def open_tunnel(
|
||||
cfg: 'config.ConfigurationType',
|
||||
) -> typing.AsyncGenerator[
|
||||
typing.Tuple[asyncio.StreamReader, asyncio.StreamWriter], None
|
||||
]:
|
||||
"""opens an ssl socket to the tunnel server"""
|
||||
if cfg.listen_ipv6 or ':' in cfg.listen_address:
|
||||
family = socket.AF_INET6
|
||||
else:
|
||||
family = socket.AF_INET
|
||||
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
context.check_hostname = False
|
||||
context.verify_mode = ssl.CERT_NONE
|
||||
reader, writer = await asyncio.open_connection(
|
||||
cfg.listen_address, cfg.listen_port, ssl=context, family=family
|
||||
)
|
||||
yield reader, writer
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
|
@ -1,111 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright (c) 2022 Virtual Cable S.L.U.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification,
|
||||
# are permitted provided that the following conditions are met:
|
||||
#
|
||||
# * Redistributions of source code must retain the above copyright notice,
|
||||
# this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
# * Neither the name of Virtual Cable S.L. nor the names of its contributors
|
||||
# may be used to endorse or promote products derived from this software
|
||||
# without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
'''
|
||||
Author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
'''
|
||||
import typing
|
||||
import random
|
||||
import asyncio
|
||||
import contextlib
|
||||
import io
|
||||
import socket
|
||||
import ssl
|
||||
import logging
|
||||
import multiprocessing
|
||||
import tempfile
|
||||
import threading
|
||||
from unittest import IsolatedAsyncioTestCase, mock
|
||||
|
||||
from uds_tunnel import tunnel, consts
|
||||
import udstunnel
|
||||
|
||||
|
||||
from . import fixtures
|
||||
from .utils import tools, certs, conf
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from uds_tunnel import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TestTunnel(IsolatedAsyncioTestCase):
|
||||
@staticmethod
|
||||
@contextlib.contextmanager
|
||||
def create_tunnel_thread(
|
||||
listen_host: str,
|
||||
listen_port: int,
|
||||
remote_host: str,
|
||||
remote_port: int,
|
||||
*,
|
||||
workers: int = 1
|
||||
) -> typing.Generator[None, None, None]:
|
||||
# Create the ssl cert
|
||||
cert, key, password = certs.selfSignedCert(listen_host)
|
||||
# Create the certificate file on /tmp
|
||||
with tempfile.NamedTemporaryFile() as cert_file:
|
||||
cert_file.write(cert.encode())
|
||||
cert_file.write(key.encode())
|
||||
cert_file.flush()
|
||||
|
||||
# Config file for the tunnel, ignore readed
|
||||
values, _ = fixtures.get_config(
|
||||
address=listen_host,
|
||||
port=listen_port,
|
||||
ssl_certificate=cert_file.name,
|
||||
ssl_certificate_key='',
|
||||
ssl_ciphers='',
|
||||
ssl_dhparam='',
|
||||
workers=workers,
|
||||
)
|
||||
args = mock.MagicMock()
|
||||
args.config = io.StringIO(fixtures.TEST_CONFIG.format(**values))
|
||||
args.ipv6 = ':' in listen_host
|
||||
|
||||
with mock.patch(
|
||||
'uds_tunnel.tunnel.TunnelProtocol._readFromUDS',
|
||||
new_callable=tools.AsyncMock,
|
||||
) as m:
|
||||
m.return_value = conf.UDS_GET_TICKET_RESPONSE(remote_host, remote_port)
|
||||
|
||||
# Create a thread to run the tunnel, udstunnel.tunnel_main will block
|
||||
# until the tunnel is closed
|
||||
thread = threading.Thread(target=udstunnel.tunnel_main, args=(args,))
|
||||
thread.start()
|
||||
yield
|
||||
# Signal stop to thead
|
||||
udstunnel.do_stop = True
|
||||
|
||||
# Wait for thread to finish
|
||||
thread.join()
|
||||
|
||||
async def test_tunnel_full(self) -> None:
|
||||
with self.create_tunnel_thread(
|
||||
'127.0.0.1', 7777, '127.0.0.1', 12345, workers=1
|
||||
):
|
||||
await asyncio.sleep(4)
|
@ -37,7 +37,7 @@ from unittest import IsolatedAsyncioTestCase, mock
|
||||
|
||||
from uds_tunnel import tunnel, consts
|
||||
|
||||
from . import fixtures
|
||||
from .utils import fixtures
|
||||
from .utils import tools, conf
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
67
tunnel-server/test/test_udstunnel.py
Normal file
67
tunnel-server/test/test_udstunnel.py
Normal file
@ -0,0 +1,67 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright (c) 2022 Virtual Cable S.L.U.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification,
|
||||
# are permitted provided that the following conditions are met:
|
||||
#
|
||||
# * Redistributions of source code must retain the above copyright notice,
|
||||
# this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
# * Neither the name of Virtual Cable S.L. nor the names of its contributors
|
||||
# may be used to endorse or promote products derived from this software
|
||||
# without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
'''
|
||||
Author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
'''
|
||||
import typing
|
||||
import random
|
||||
import asyncio
|
||||
import logging
|
||||
from unittest import IsolatedAsyncioTestCase, mock
|
||||
|
||||
from uds_tunnel import consts
|
||||
|
||||
from .utils import tuntools
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TestUDSTunnel(IsolatedAsyncioTestCase):
|
||||
|
||||
async def test_tunnel_fail_cmd_full(self) -> None:
|
||||
consts.TIMEOUT_COMMAND = 0.1 # type: ignore # timeout is a final variable, but we need to change it for testing speed
|
||||
for i in range(0, 100, 10):
|
||||
# Set timeout to 1 seconds
|
||||
bad_cmd = bytes(random.randint(0, 255) for _ in range(i)) # Some garbage
|
||||
logger.info(f'Testing invalid command with {bad_cmd!r}')
|
||||
async with tuntools.create_tunnel_proc(
|
||||
'127.0.0.1', 7777, '127.0.0.1', 12345, workers=1
|
||||
) as cfg:
|
||||
# On full, we need the handshake to be done, before connecting
|
||||
async with tuntools.open_tunnel_client(cfg, use_tunnel_handshake=True) as (creader, cwriter):
|
||||
cwriter.write(bad_cmd)
|
||||
await cwriter.drain()
|
||||
# Read response
|
||||
data = await creader.read(1024)
|
||||
# if len(bad_cmd) < consts.COMMAND_LENGTH, response will be RESPONSE_ERROR_TIMEOUT
|
||||
if len(bad_cmd) >= consts.COMMAND_LENGTH:
|
||||
self.assertEqual(data, consts.RESPONSE_ERROR_COMMAND)
|
||||
else:
|
||||
self.assertEqual(data, consts.RESPONSE_ERROR_TIMEOUT)
|
||||
|
||||
|
@ -46,7 +46,7 @@ from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
|
||||
|
||||
def selfSignedCert(ip: str) -> typing.Tuple[str, str, str]:
|
||||
def selfSignedCert(ip: str, use_password: bool = True) -> typing.Tuple[str, str, str]:
|
||||
key = rsa.generate_private_key(
|
||||
public_exponent=65537,
|
||||
key_size=2048,
|
||||
@ -72,14 +72,16 @@ def selfSignedCert(ip: str) -> typing.Tuple[str, str, str]:
|
||||
.add_extension(san, False)
|
||||
.sign(key, hashes.SHA256(), default_backend())
|
||||
)
|
||||
|
||||
args: typing.Dict[str, typing.Any] = {
|
||||
'encoding': serialization.Encoding.PEM,
|
||||
'format': serialization.PrivateFormat.TraditionalOpenSSL,
|
||||
}
|
||||
if use_password:
|
||||
args['encryption_algorithm'] = serialization.BestAvailableEncryption(password.encode())
|
||||
else:
|
||||
args['encryption_algorithm'] = serialization.NoEncryption()
|
||||
return (
|
||||
key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.TraditionalOpenSSL,
|
||||
encryption_algorithm=serialization.BestAvailableEncryption(
|
||||
password.encode()
|
||||
),
|
||||
key.private_bytes(**args
|
||||
).decode(),
|
||||
cert.public_bytes(encoding=serialization.Encoding.PEM).decode(),
|
||||
password,
|
||||
|
235
tunnel-server/test/utils/tuntools.py
Normal file
235
tunnel-server/test/utils/tuntools.py
Normal file
@ -0,0 +1,235 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright (c) 2022 Virtual Cable S.L.U.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification,
|
||||
# are permitted provided that the following conditions are met:
|
||||
#
|
||||
# * Redistributions of source code must retain the above copyright notice,
|
||||
# this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
# * Neither the name of Virtual Cable S.L. nor the names of its contributors
|
||||
# may be used to endorse or promote products derived from this software
|
||||
# without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
'''
|
||||
Author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
'''
|
||||
import asyncio
|
||||
import contextlib
|
||||
import io
|
||||
import logging
|
||||
import socket
|
||||
import ssl
|
||||
import tempfile
|
||||
import threading
|
||||
import random
|
||||
import typing
|
||||
from unittest import mock
|
||||
import multiprocessing
|
||||
|
||||
import udstunnel
|
||||
from uds_tunnel import consts, tunnel, stats
|
||||
|
||||
from . import certs, conf, fixtures, tools
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from uds_tunnel import config
|
||||
from multiprocessing.connection import Connection
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def create_tunnel_proc(
|
||||
listen_host: str,
|
||||
listen_port: int,
|
||||
remote_host: str,
|
||||
remote_port: int,
|
||||
*,
|
||||
workers: int = 1
|
||||
) -> typing.AsyncGenerator['config.ConfigurationType', None]:
|
||||
# Create the ssl cert
|
||||
cert, key, password = certs.selfSignedCert(listen_host, use_password=False)
|
||||
# Create the certificate file on /tmp
|
||||
cert_file = '/tmp/tunnel_full_cert.pem'
|
||||
with open(cert_file, 'w') as f:
|
||||
f.write(key)
|
||||
f.write(cert)
|
||||
|
||||
# Config file for the tunnel, ignore readed
|
||||
values, cfg = fixtures.get_config(
|
||||
address=listen_host,
|
||||
port=listen_port,
|
||||
ssl_certificate=cert_file,
|
||||
ssl_certificate_key='',
|
||||
ssl_password=password,
|
||||
ssl_ciphers='',
|
||||
ssl_dhparam='',
|
||||
workers=workers,
|
||||
)
|
||||
args = mock.MagicMock()
|
||||
args.config = io.StringIO(fixtures.TEST_CONFIG.format(**values))
|
||||
args.ipv6 = ':' in listen_host
|
||||
|
||||
with mock.patch(
|
||||
'uds_tunnel.tunnel.TunnelProtocol._readFromUDS',
|
||||
new_callable=tools.AsyncMock,
|
||||
) as m:
|
||||
m.return_value = conf.UDS_GET_TICKET_RESPONSE(remote_host, remote_port)
|
||||
|
||||
# Stats collector
|
||||
gs = stats.GlobalStats()
|
||||
# Pipe to send data to tunnel
|
||||
own_end, other_end = multiprocessing.Pipe()
|
||||
|
||||
# Set running flag
|
||||
udstunnel.running.set()
|
||||
|
||||
# Create the tunnel task
|
||||
task = asyncio.create_task(udstunnel.tunnel_proc_async(other_end, cfg, gs.ns))
|
||||
|
||||
# Create a small asyncio server that reads the handshake,
|
||||
# and sends the socket to the tunnel_proc_async using the pipe
|
||||
# the pipe message will be typing.Tuple[socket.socket, typing.Tuple[str, int]]
|
||||
# socket and address
|
||||
async def client_connected_db(reader, writer):
|
||||
# Read the handshake
|
||||
data = await reader.read(1024)
|
||||
# For testing, we ignore the handshake value
|
||||
# Send the socket to the tunnel
|
||||
own_end.send(
|
||||
(
|
||||
writer.get_extra_info('socket').dup(),
|
||||
writer.get_extra_info('peername'),
|
||||
)
|
||||
)
|
||||
# Close the socket
|
||||
writer.close()
|
||||
|
||||
server = await asyncio.start_server(
|
||||
client_connected_db,
|
||||
listen_host,
|
||||
listen_port,
|
||||
)
|
||||
try:
|
||||
yield cfg
|
||||
finally:
|
||||
# Close the pipe (both ends)
|
||||
own_end.close()
|
||||
|
||||
task.cancel()
|
||||
# wait for the task to finish
|
||||
await task
|
||||
|
||||
server.close()
|
||||
await server.wait_closed()
|
||||
logger.info('Server closed')
|
||||
|
||||
|
||||
async def create_tunnel_server(
|
||||
cfg: 'config.ConfigurationType', context: 'ssl.SSLContext'
|
||||
) -> 'asyncio.Server':
|
||||
# Create fake proxy
|
||||
proxy = mock.MagicMock()
|
||||
proxy.cfg = cfg
|
||||
proxy.ns = mock.MagicMock()
|
||||
proxy.ns.current = 0
|
||||
proxy.ns.total = 0
|
||||
proxy.ns.sent = 0
|
||||
proxy.ns.recv = 0
|
||||
proxy.counter = 0
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
# Create an asyncio listen socket on cfg.listen_host:cfg.listen_port
|
||||
return await loop.create_server(
|
||||
lambda: tunnel.TunnelProtocol(proxy),
|
||||
cfg.listen_address,
|
||||
cfg.listen_port,
|
||||
ssl=context,
|
||||
family=socket.AF_INET6
|
||||
if cfg.listen_ipv6 or ':' in cfg.listen_address
|
||||
else socket.AF_INET,
|
||||
)
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def create_test_tunnel(
|
||||
*, callback: typing.Callable[[bytes], None]
|
||||
) -> typing.AsyncGenerator['config.ConfigurationType', None]:
|
||||
# Generate a listening server for testing tunnel
|
||||
# Prepare the end of the tunnel
|
||||
async with tools.AsyncTCPServer(port=54876, callback=callback) as server:
|
||||
# Create a tunnel to localhost 13579
|
||||
# SSl cert for tunnel server
|
||||
with certs.ssl_context(server.host) as (ssl_ctx, _):
|
||||
_, cfg = fixtures.get_config(
|
||||
address=server.host,
|
||||
port=7777,
|
||||
)
|
||||
with mock.patch(
|
||||
'uds_tunnel.tunnel.TunnelProtocol._readFromUDS',
|
||||
new_callable=tools.AsyncMock,
|
||||
) as m:
|
||||
m.return_value = conf.UDS_GET_TICKET_RESPONSE(server.host, server.port)
|
||||
|
||||
tunnel_server = await create_tunnel_server(cfg, ssl_ctx)
|
||||
try:
|
||||
yield cfg
|
||||
finally:
|
||||
tunnel_server.close()
|
||||
await tunnel_server.wait_closed()
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def open_tunnel_client(
|
||||
cfg: 'config.ConfigurationType',
|
||||
use_tunnel_handshake: bool = False,
|
||||
) -> typing.AsyncGenerator[
|
||||
typing.Tuple[asyncio.StreamReader, asyncio.StreamWriter], None
|
||||
]:
|
||||
"""opens an ssl socket to the tunnel server"""
|
||||
loop = asyncio.get_running_loop()
|
||||
if cfg.listen_ipv6 or ':' in cfg.listen_address:
|
||||
family = socket.AF_INET6
|
||||
else:
|
||||
family = socket.AF_INET
|
||||
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
context.check_hostname = False
|
||||
context.verify_mode = ssl.CERT_NONE
|
||||
|
||||
if not use_tunnel_handshake:
|
||||
reader, writer = await asyncio.open_connection(
|
||||
cfg.listen_address, cfg.listen_port, ssl=context, family=family
|
||||
)
|
||||
else:
|
||||
# Open the socket, send handshake and then upgrade to ssl, non blocking
|
||||
sock = socket.socket(family, socket.SOCK_STREAM)
|
||||
# Set socket to non blocking
|
||||
sock.setblocking(False)
|
||||
await loop.sock_connect(sock, (cfg.listen_address, cfg.listen_port))
|
||||
await loop.sock_sendall(sock, consts.HANDSHAKE_V1)
|
||||
# upgrade to ssl
|
||||
reader, writer = await asyncio.open_connection(
|
||||
sock=sock, ssl=context, server_hostname=cfg.listen_address
|
||||
)
|
||||
try:
|
||||
yield reader, writer
|
||||
finally:
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
Loading…
Reference in New Issue
Block a user