1
0
mirror of https://github.com/dkmstr/openuds.git synced 2024-12-23 17:34:17 +03:00

Added test to udstunnel man async proc

This commit is contained in:
Adolfo Gómez García 2022-12-17 21:03:00 +01:00
parent 9f4159f18d
commit 370799912f
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
10 changed files with 369 additions and 236 deletions

View File

@ -56,6 +56,7 @@ class ConfigurationType(typing.NamedTuple):
ssl_certificate: str
ssl_certificate_key: str
ssl_password: str
ssl_ciphers: str
ssl_dhparam: str
@ -122,6 +123,7 @@ def read(
workers=int(uds.get('workers', '0')) or multiprocessing.cpu_count(),
ssl_certificate=uds['ssl_certificate'],
ssl_certificate_key=uds.get('ssl_certificate_key', ''),
ssl_password=uds.get('ssl_password', ''),
ssl_ciphers=uds.get('ssl_ciphers'),
ssl_dhparam=uds.get('ssl_dhparam'),
uds_server=uds_server,

View File

@ -39,6 +39,8 @@ import ssl
import socket
import logging
from concurrent.futures import ThreadPoolExecutor
# event for stop notification
import threading
import typing
try:
@ -62,12 +64,12 @@ if typing.TYPE_CHECKING:
logger = logging.getLogger(__name__)
do_stop = False
running: threading.Event = threading.Event()
def stop_signal(signum: int, frame: typing.Any) -> None:
global do_stop
do_stop = True
global running
running.clear()
logger.debug('SIGNAL %s, frame: %s', signum, frame)
@ -119,8 +121,13 @@ async def tunnel_proc_async(
] = pipe.recv()
if msg:
return msg
except EOFError:
logger.debug('Parent process closed connection')
pipe.close()
return None, None
except Exception:
logger.exception('Receiving data from parent process')
pipe.close()
return None, None
async def run_server() -> None:
@ -129,11 +136,15 @@ async def tunnel_proc_async(
# Generate SSL context
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
args: typing.Dict[str, typing.Any] = {
'certfile': cfg.ssl_certificate,
}
if cfg.ssl_certificate_key:
context.load_cert_chain(cfg.ssl_certificate, cfg.ssl_certificate_key)
else:
context.load_cert_chain(cfg.ssl_certificate)
args['keyfile'] = cfg.ssl_certificate_key
if cfg.ssl_password:
args['password'] = cfg.ssl_password
context.load_cert_chain(**args)
if cfg.ssl_ciphers:
context.set_ciphers(cfg.ssl_ciphers)
@ -141,29 +152,37 @@ async def tunnel_proc_async(
if cfg.ssl_dhparam:
context.load_dh_params(cfg.ssl_dhparam)
while True:
address: typing.Optional[typing.Tuple[str, int]] = ('', 0)
try:
(sock, address) = await loop.run_in_executor(None, get_socket)
if not sock:
break # No more sockets, exit
logger.debug(f'CONNECTION from {address!r} (pid: {os.getpid()})')
tasks.append(asyncio.create_task(tunneler(sock, context)))
except Exception:
logger.error('NEGOTIATION ERROR from %s', address[0] if address else 'unknown')
try:
while True:
address: typing.Optional[typing.Tuple[str, int]] = ('', 0)
try:
(sock, address) = await loop.run_in_executor(None, get_socket)
if not sock:
break # No more sockets, exit
logger.debug(f'CONNECTION from {address!r} (pid: {os.getpid()})')
tasks.append(asyncio.create_task(tunneler(sock, context)))
except asyncio.CancelledError:
raise
except Exception:
logger.error('NEGOTIATION ERROR from %s', address[0] if address else 'unknown')
except asyncio.CancelledError:
pass # Stop
# create task for server
tasks.append(asyncio.create_task(run_server()))
while tasks and not do_stop:
to_wait = tasks[:] # Get a copy of the list, and clean the original
# Wait for tasks to finish
done, _ = await asyncio.wait(to_wait, return_when=asyncio.FIRST_COMPLETED)
# Remove finished tasks
for task in done:
tasks.remove(task)
if task.exception():
logger.exception('TUNNEL ERROR')
try:
while tasks and running.is_set():
to_wait = tasks[:] # Get a copy of the list, and clean the original
# Wait for tasks to finish
done, _ = await asyncio.wait(to_wait, return_when=asyncio.FIRST_COMPLETED, timeout=2)
# Remove finished tasks
for task in done:
tasks.remove(task)
if task.exception():
logger.exception('TUNNEL ERROR')
except asyncio.CancelledError:
running.clear() # ensure we stop
# If any task is still running, cancel it
for task in tasks:
@ -244,16 +263,18 @@ def tunnel_main(args: 'argparse.Namespace') -> None:
signal.signal(signal.SIGINT, stop_signal)
signal.signal(signal.SIGTERM, stop_signal)
except Exception as e:
# Signal not available on threads, and testing uses threads
# Signal not available on threads, and we use threads on tests,
# so we will ignore this because on tests signals are not important
logger.warning('Signal not available: %s', e)
stats_collector = stats.GlobalStats()
prcs = processes.Processes(tunnel_proc_async, cfg, stats_collector.ns)
running.set()
with ThreadPoolExecutor(max_workers=256) as executor:
try:
while not do_stop:
while running.is_set():
try:
client, addr = sock.accept()
logger.info('CONNECTION from %s', addr)

View File

@ -32,7 +32,7 @@ import hashlib
from unittest import TestCase
from . import fixtures
from .utils import fixtures
class TestConfigFile(TestCase):

View File

@ -30,10 +30,7 @@ Author: Adolfo Gómez, dkmaster at dkmon dot com
'''
import typing
import random
import asyncio
import contextlib
import socket
import ssl
import logging
import multiprocessing
from unittest import IsolatedAsyncioTestCase, mock
@ -41,11 +38,7 @@ from unittest import IsolatedAsyncioTestCase, mock
from udstunnel import process_connection
from uds_tunnel import tunnel, consts
from . import fixtures
from .utils import tools, certs, conf
if typing.TYPE_CHECKING:
from uds_tunnel import config
from .utils import tuntools
logger = logging.getLogger(__name__)
@ -63,16 +56,16 @@ class TestTunnel(IsolatedAsyncioTestCase):
# Send invalid commands and see what happens
# Commands are 4 bytes length, try with less and more invalid commands
consts.TIMEOUT_COMMAND = 0.1 # type: ignore # timeout is a final variable, but we need to change it for testing speed
for i in range(0, 100, 10):
# Set timeout to 1 seconds
bad_cmd = bytes(random.randint(0, 255) for _ in range(i)) # Some garbage
consts.TIMEOUT_COMMAND = 0.1 # type: ignore # timeout is a final variable, but we need to change it for testing speed
logger.info(f'Testing invalid command with {bad_cmd!r}')
async with TestTunnel.create_test_tunnel(callback=lambda x: None) as cfg:
async with tuntools.create_test_tunnel(callback=lambda x: None) as cfg:
logger_mock = mock.MagicMock()
with mock.patch('uds_tunnel.tunnel.logger', logger_mock):
# Open connection to tunnel
async with TestTunnel.open_tunnel(cfg) as (reader, writer):
async with tuntools.open_tunnel_client(cfg) as (reader, writer):
# Send data
writer.write(bad_cmd)
await writer.drain()
@ -151,79 +144,3 @@ class TestTunnel(IsolatedAsyncioTestCase):
# recv()[0] will be a copy of the socket, we don't care about it
self.assertEqual(other_conn.recv()[1], ('host', 'port'))
@staticmethod
async def create_tunnel_server(
cfg: 'config.ConfigurationType', context: 'ssl.SSLContext'
) -> 'asyncio.Server':
# Create fake proxy
proxy = mock.MagicMock()
proxy.cfg = cfg
proxy.ns = mock.MagicMock()
proxy.ns.current = 0
proxy.ns.total = 0
proxy.ns.sent = 0
proxy.ns.recv = 0
proxy.counter = 0
loop = asyncio.get_running_loop()
# Create an asyncio listen socket on cfg.listen_host:cfg.listen_port
return await loop.create_server(
lambda: tunnel.TunnelProtocol(proxy),
cfg.listen_address,
cfg.listen_port,
ssl=context,
family=socket.AF_INET6
if cfg.listen_ipv6 or ':' in cfg.listen_address
else socket.AF_INET,
)
@staticmethod
@contextlib.asynccontextmanager
async def create_test_tunnel(
*, callback: typing.Callable[[bytes], None]
) -> typing.AsyncGenerator['config.ConfigurationType', None]:
# Generate a listening server for testing tunnel
# Prepare the end of the tunnel
async with tools.AsyncTCPServer(port=54876, callback=callback) as server:
# Create a tunnel to localhost 13579
# SSl cert for tunnel server
with certs.ssl_context(server.host) as (ssl_ctx, _):
_, cfg = fixtures.get_config(
address=server.host,
port=7777,
)
with mock.patch(
'uds_tunnel.tunnel.TunnelProtocol._readFromUDS',
new_callable=tools.AsyncMock,
) as m:
m.return_value = conf.UDS_GET_TICKET_RESPONSE(
server.host, server.port
)
tunnel_server = await TestTunnel.create_tunnel_server(cfg, ssl_ctx)
yield cfg
tunnel_server.close()
await tunnel_server.wait_closed()
@staticmethod
@contextlib.asynccontextmanager
async def open_tunnel(
cfg: 'config.ConfigurationType',
) -> typing.AsyncGenerator[
typing.Tuple[asyncio.StreamReader, asyncio.StreamWriter], None
]:
"""opens an ssl socket to the tunnel server"""
if cfg.listen_ipv6 or ':' in cfg.listen_address:
family = socket.AF_INET6
else:
family = socket.AF_INET
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
context.check_hostname = False
context.verify_mode = ssl.CERT_NONE
reader, writer = await asyncio.open_connection(
cfg.listen_address, cfg.listen_port, ssl=context, family=family
)
yield reader, writer
writer.close()
await writer.wait_closed()

View File

@ -1,111 +0,0 @@
# -*- coding: utf-8 -*-
#
# Copyright (c) 2022 Virtual Cable S.L.U.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of Virtual Cable S.L. nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
'''
Author: Adolfo Gómez, dkmaster at dkmon dot com
'''
import typing
import random
import asyncio
import contextlib
import io
import socket
import ssl
import logging
import multiprocessing
import tempfile
import threading
from unittest import IsolatedAsyncioTestCase, mock
from uds_tunnel import tunnel, consts
import udstunnel
from . import fixtures
from .utils import tools, certs, conf
if typing.TYPE_CHECKING:
from uds_tunnel import config
logger = logging.getLogger(__name__)
class TestTunnel(IsolatedAsyncioTestCase):
@staticmethod
@contextlib.contextmanager
def create_tunnel_thread(
listen_host: str,
listen_port: int,
remote_host: str,
remote_port: int,
*,
workers: int = 1
) -> typing.Generator[None, None, None]:
# Create the ssl cert
cert, key, password = certs.selfSignedCert(listen_host)
# Create the certificate file on /tmp
with tempfile.NamedTemporaryFile() as cert_file:
cert_file.write(cert.encode())
cert_file.write(key.encode())
cert_file.flush()
# Config file for the tunnel, ignore readed
values, _ = fixtures.get_config(
address=listen_host,
port=listen_port,
ssl_certificate=cert_file.name,
ssl_certificate_key='',
ssl_ciphers='',
ssl_dhparam='',
workers=workers,
)
args = mock.MagicMock()
args.config = io.StringIO(fixtures.TEST_CONFIG.format(**values))
args.ipv6 = ':' in listen_host
with mock.patch(
'uds_tunnel.tunnel.TunnelProtocol._readFromUDS',
new_callable=tools.AsyncMock,
) as m:
m.return_value = conf.UDS_GET_TICKET_RESPONSE(remote_host, remote_port)
# Create a thread to run the tunnel, udstunnel.tunnel_main will block
# until the tunnel is closed
thread = threading.Thread(target=udstunnel.tunnel_main, args=(args,))
thread.start()
yield
# Signal stop to thead
udstunnel.do_stop = True
# Wait for thread to finish
thread.join()
async def test_tunnel_full(self) -> None:
with self.create_tunnel_thread(
'127.0.0.1', 7777, '127.0.0.1', 12345, workers=1
):
await asyncio.sleep(4)

View File

@ -37,7 +37,7 @@ from unittest import IsolatedAsyncioTestCase, mock
from uds_tunnel import tunnel, consts
from . import fixtures
from .utils import fixtures
from .utils import tools, conf
logger = logging.getLogger(__name__)

View File

@ -0,0 +1,67 @@
# -*- coding: utf-8 -*-
#
# Copyright (c) 2022 Virtual Cable S.L.U.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of Virtual Cable S.L. nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
'''
Author: Adolfo Gómez, dkmaster at dkmon dot com
'''
import typing
import random
import asyncio
import logging
from unittest import IsolatedAsyncioTestCase, mock
from uds_tunnel import consts
from .utils import tuntools
logger = logging.getLogger(__name__)
class TestUDSTunnel(IsolatedAsyncioTestCase):
async def test_tunnel_fail_cmd_full(self) -> None:
consts.TIMEOUT_COMMAND = 0.1 # type: ignore # timeout is a final variable, but we need to change it for testing speed
for i in range(0, 100, 10):
# Set timeout to 1 seconds
bad_cmd = bytes(random.randint(0, 255) for _ in range(i)) # Some garbage
logger.info(f'Testing invalid command with {bad_cmd!r}')
async with tuntools.create_tunnel_proc(
'127.0.0.1', 7777, '127.0.0.1', 12345, workers=1
) as cfg:
# On full, we need the handshake to be done, before connecting
async with tuntools.open_tunnel_client(cfg, use_tunnel_handshake=True) as (creader, cwriter):
cwriter.write(bad_cmd)
await cwriter.drain()
# Read response
data = await creader.read(1024)
# if len(bad_cmd) < consts.COMMAND_LENGTH, response will be RESPONSE_ERROR_TIMEOUT
if len(bad_cmd) >= consts.COMMAND_LENGTH:
self.assertEqual(data, consts.RESPONSE_ERROR_COMMAND)
else:
self.assertEqual(data, consts.RESPONSE_ERROR_TIMEOUT)

View File

@ -46,7 +46,7 @@ from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
def selfSignedCert(ip: str) -> typing.Tuple[str, str, str]:
def selfSignedCert(ip: str, use_password: bool = True) -> typing.Tuple[str, str, str]:
key = rsa.generate_private_key(
public_exponent=65537,
key_size=2048,
@ -72,14 +72,16 @@ def selfSignedCert(ip: str) -> typing.Tuple[str, str, str]:
.add_extension(san, False)
.sign(key, hashes.SHA256(), default_backend())
)
args: typing.Dict[str, typing.Any] = {
'encoding': serialization.Encoding.PEM,
'format': serialization.PrivateFormat.TraditionalOpenSSL,
}
if use_password:
args['encryption_algorithm'] = serialization.BestAvailableEncryption(password.encode())
else:
args['encryption_algorithm'] = serialization.NoEncryption()
return (
key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.BestAvailableEncryption(
password.encode()
),
key.private_bytes(**args
).decode(),
cert.public_bytes(encoding=serialization.Encoding.PEM).decode(),
password,

View File

@ -0,0 +1,235 @@
# -*- coding: utf-8 -*-
#
# Copyright (c) 2022 Virtual Cable S.L.U.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of Virtual Cable S.L. nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
'''
Author: Adolfo Gómez, dkmaster at dkmon dot com
'''
import asyncio
import contextlib
import io
import logging
import socket
import ssl
import tempfile
import threading
import random
import typing
from unittest import mock
import multiprocessing
import udstunnel
from uds_tunnel import consts, tunnel, stats
from . import certs, conf, fixtures, tools
if typing.TYPE_CHECKING:
from uds_tunnel import config
from multiprocessing.connection import Connection
logger = logging.getLogger(__name__)
@contextlib.asynccontextmanager
async def create_tunnel_proc(
listen_host: str,
listen_port: int,
remote_host: str,
remote_port: int,
*,
workers: int = 1
) -> typing.AsyncGenerator['config.ConfigurationType', None]:
# Create the ssl cert
cert, key, password = certs.selfSignedCert(listen_host, use_password=False)
# Create the certificate file on /tmp
cert_file = '/tmp/tunnel_full_cert.pem'
with open(cert_file, 'w') as f:
f.write(key)
f.write(cert)
# Config file for the tunnel, ignore readed
values, cfg = fixtures.get_config(
address=listen_host,
port=listen_port,
ssl_certificate=cert_file,
ssl_certificate_key='',
ssl_password=password,
ssl_ciphers='',
ssl_dhparam='',
workers=workers,
)
args = mock.MagicMock()
args.config = io.StringIO(fixtures.TEST_CONFIG.format(**values))
args.ipv6 = ':' in listen_host
with mock.patch(
'uds_tunnel.tunnel.TunnelProtocol._readFromUDS',
new_callable=tools.AsyncMock,
) as m:
m.return_value = conf.UDS_GET_TICKET_RESPONSE(remote_host, remote_port)
# Stats collector
gs = stats.GlobalStats()
# Pipe to send data to tunnel
own_end, other_end = multiprocessing.Pipe()
# Set running flag
udstunnel.running.set()
# Create the tunnel task
task = asyncio.create_task(udstunnel.tunnel_proc_async(other_end, cfg, gs.ns))
# Create a small asyncio server that reads the handshake,
# and sends the socket to the tunnel_proc_async using the pipe
# the pipe message will be typing.Tuple[socket.socket, typing.Tuple[str, int]]
# socket and address
async def client_connected_db(reader, writer):
# Read the handshake
data = await reader.read(1024)
# For testing, we ignore the handshake value
# Send the socket to the tunnel
own_end.send(
(
writer.get_extra_info('socket').dup(),
writer.get_extra_info('peername'),
)
)
# Close the socket
writer.close()
server = await asyncio.start_server(
client_connected_db,
listen_host,
listen_port,
)
try:
yield cfg
finally:
# Close the pipe (both ends)
own_end.close()
task.cancel()
# wait for the task to finish
await task
server.close()
await server.wait_closed()
logger.info('Server closed')
async def create_tunnel_server(
cfg: 'config.ConfigurationType', context: 'ssl.SSLContext'
) -> 'asyncio.Server':
# Create fake proxy
proxy = mock.MagicMock()
proxy.cfg = cfg
proxy.ns = mock.MagicMock()
proxy.ns.current = 0
proxy.ns.total = 0
proxy.ns.sent = 0
proxy.ns.recv = 0
proxy.counter = 0
loop = asyncio.get_running_loop()
# Create an asyncio listen socket on cfg.listen_host:cfg.listen_port
return await loop.create_server(
lambda: tunnel.TunnelProtocol(proxy),
cfg.listen_address,
cfg.listen_port,
ssl=context,
family=socket.AF_INET6
if cfg.listen_ipv6 or ':' in cfg.listen_address
else socket.AF_INET,
)
@contextlib.asynccontextmanager
async def create_test_tunnel(
*, callback: typing.Callable[[bytes], None]
) -> typing.AsyncGenerator['config.ConfigurationType', None]:
# Generate a listening server for testing tunnel
# Prepare the end of the tunnel
async with tools.AsyncTCPServer(port=54876, callback=callback) as server:
# Create a tunnel to localhost 13579
# SSl cert for tunnel server
with certs.ssl_context(server.host) as (ssl_ctx, _):
_, cfg = fixtures.get_config(
address=server.host,
port=7777,
)
with mock.patch(
'uds_tunnel.tunnel.TunnelProtocol._readFromUDS',
new_callable=tools.AsyncMock,
) as m:
m.return_value = conf.UDS_GET_TICKET_RESPONSE(server.host, server.port)
tunnel_server = await create_tunnel_server(cfg, ssl_ctx)
try:
yield cfg
finally:
tunnel_server.close()
await tunnel_server.wait_closed()
@contextlib.asynccontextmanager
async def open_tunnel_client(
cfg: 'config.ConfigurationType',
use_tunnel_handshake: bool = False,
) -> typing.AsyncGenerator[
typing.Tuple[asyncio.StreamReader, asyncio.StreamWriter], None
]:
"""opens an ssl socket to the tunnel server"""
loop = asyncio.get_running_loop()
if cfg.listen_ipv6 or ':' in cfg.listen_address:
family = socket.AF_INET6
else:
family = socket.AF_INET
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
context.check_hostname = False
context.verify_mode = ssl.CERT_NONE
if not use_tunnel_handshake:
reader, writer = await asyncio.open_connection(
cfg.listen_address, cfg.listen_port, ssl=context, family=family
)
else:
# Open the socket, send handshake and then upgrade to ssl, non blocking
sock = socket.socket(family, socket.SOCK_STREAM)
# Set socket to non blocking
sock.setblocking(False)
await loop.sock_connect(sock, (cfg.listen_address, cfg.listen_port))
await loop.sock_sendall(sock, consts.HANDSHAKE_V1)
# upgrade to ssl
reader, writer = await asyncio.open_connection(
sock=sock, ssl=context, server_hostname=cfg.listen_address
)
try:
yield reader, writer
finally:
writer.close()
await writer.wait_closed()