1
0
mirror of https://github.com/dkmstr/openuds.git synced 2025-03-11 00:58:39 +03:00

added basic structure for testing tunnel server and some tests

This commit is contained in:
Adolfo Gómez García 2022-12-16 17:57:57 +01:00
parent 406f32c2fa
commit 081dfc9995
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
12 changed files with 420 additions and 103 deletions

View File

@ -38,16 +38,29 @@ import threading
import select
import typing
import logging
import enum
from . import tools
HANDSHAKE_V1 = b'\x5AMGB\xA5\x01\x00'
BUFFER_SIZE = 1024 * 16 # Max buffer length
DEBUG = True
LISTEN_ADDRESS = '0.0.0.0' if DEBUG else '127.0.0.1'
BUFFER_SIZE: typing.Final[int] = 1024 * 16 # Max buffer length
LISTEN_ADDRESS: typing.Final[str] = '0.0.0.0' if DEBUG else '127.0.0.1'
LISTEN_ADDRESS_V6: typing.Final[str] = '::' if DEBUG else '::1'
# ForwarServer states
TUNNEL_LISTENING, TUNNEL_OPENING, TUNNEL_PROCESSING, TUNNEL_ERROR = 0, 1, 2, 3
class ForwardState(enum.IntEnum):
TUNNEL_LISTENING = 0
TUNNEL_OPENING = 1
TUNNEL_PROCESSING = 2
TUNNEL_ERROR = 3
# Some constants strings for protocol
HANDSHAKE_V1: typing.Final[bytes] = b'\x5AMGB\xA5\x01\x00'
CMD_TEST: typing.Final[bytes] = b'TEST'
CMD_OPEN: typing.Final[bytes] = b'OPEN'
RESPONSE_OK: typing.Final[bytes] = b'OK'
logger = logging.getLogger(__name__)
@ -57,6 +70,7 @@ class ForwardServer(socketserver.ThreadingTCPServer):
allow_reuse_address = True
remote: typing.Tuple[str, int]
remote_ipv6: bool
ticket: str
stop_flag: threading.Event
can_stop: bool
@ -64,7 +78,9 @@ class ForwardServer(socketserver.ThreadingTCPServer):
timer: typing.Optional[threading.Timer]
check_certificate: bool
current_connections: int
status: int
status: ForwardState
address_family = socket.AF_INET
def __init__(
self,
@ -73,14 +89,21 @@ class ForwardServer(socketserver.ThreadingTCPServer):
timeout: int = 0,
local_port: int = 0,
check_certificate: bool = True,
ipv6_listen: bool = False,
ipv6_remote: bool = False,
) -> None:
local_port = local_port or random.randrange(33000, 53000)
if ipv6_listen:
self.address_family = socket.AF_INET6
super().__init__(
server_address=(LISTEN_ADDRESS, local_port), RequestHandlerClass=Handler
server_address=(LISTEN_ADDRESS if ipv6_listen else LISTEN_ADDRESS_V6, local_port),
RequestHandlerClass=Handler,
)
self.remote = remote
self.remote_ipv6 = ipv6_remote or ':' in remote[0] # if ':' in remote address, it's ipv6 (port is [1])
self.ticket = ticket
# Negative values for timeout, means "accept always connections"
# "but if no connection is stablished on timeout (positive)"
@ -90,7 +113,7 @@ class ForwardServer(socketserver.ThreadingTCPServer):
self.stop_flag = threading.Event() # False initial
self.current_connections = 0
self.status = TUNNEL_LISTENING
self.status = ForwardState.TUNNEL_LISTENING
self.can_stop = False
timeout = abs(timeout) or 60
@ -109,7 +132,7 @@ class ForwardServer(socketserver.ThreadingTCPServer):
self.shutdown()
def connect(self) -> ssl.SSLSocket:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as rsocket:
with socket.socket(socket.AF_INET6 if self.remote_ipv6 else socket.AF_INET, socket.SOCK_STREAM) as rsocket:
logger.info('CONNECT to %s', self.remote)
rsocket.connect(self.remote)
@ -134,16 +157,16 @@ class ForwardServer(socketserver.ThreadingTCPServer):
return context.wrap_socket(rsocket, server_hostname=self.remote[0])
def check(self) -> bool:
if self.status == TUNNEL_ERROR:
if self.status == ForwardState.TUNNEL_ERROR:
return False
logger.debug('Checking tunnel availability')
try:
with self.connect() as ssl_socket:
ssl_socket.sendall(b'TEST')
ssl_socket.sendall(CMD_TEST)
resp = ssl_socket.recv(2)
if resp != b'OK':
if resp != RESPONSE_OK:
raise Exception({'Invalid tunnelresponse: {resp}'})
logger.debug('Tunnel is available!')
return True
@ -173,11 +196,11 @@ class Handler(socketserver.BaseRequestHandler):
# server: ForwardServer
def handle(self) -> None:
self.server.status = TUNNEL_OPENING
self.server.status = ForwardState.TUNNEL_OPENING
# If server processing is over time
# If server new connections processing are over time...
if self.server.stoppable:
self.server.status = TUNNEL_ERROR
self.server.status = ForwardState.TUNNEL_ERROR
logger.info('Rejected timedout connection')
self.request.close() # End connection without processing it
return
@ -189,10 +212,10 @@ class Handler(socketserver.BaseRequestHandler):
logger.debug('Ticket %s', self.server.ticket)
with self.server.connect() as ssl_socket:
# Send handhshake + command + ticket
ssl_socket.sendall(b'OPEN' + self.server.ticket.encode())
ssl_socket.sendall(CMD_OPEN + self.server.ticket.encode())
# Check response is OK
data = ssl_socket.recv(2)
if data != b'OK':
if data != RESPONSE_OK:
data += ssl_socket.recv(128)
raise Exception(
f'Error received: {data.decode(errors="ignore")}'
@ -202,7 +225,7 @@ class Handler(socketserver.BaseRequestHandler):
self.process(remote=ssl_socket)
except Exception as e:
logger.error(f'Error connecting to {self.server.remote!s}: {e!s}')
self.server.status = TUNNEL_ERROR
self.server.status = ForwardState.TUNNEL_ERROR
self.server.stop()
finally:
self.server.current_connections -= 1
@ -212,7 +235,7 @@ class Handler(socketserver.BaseRequestHandler):
# Processes data forwarding
def process(self, remote: ssl.SSLSocket):
self.server.status = TUNNEL_PROCESSING
self.server.status = ForwardState.TUNNEL_PROCESSING
logger.debug('Processing tunnel with ticket %s', self.server.ticket)
# Process data until stop requested or connection closed
try:

View File

@ -50,6 +50,7 @@ class ConfigurationType(typing.NamedTuple):
listen_address: str
listen_port: int
listen_ipv6: bool
workers: int
@ -117,9 +118,10 @@ def read(
log_number=int(uds.get('lognumber', '3')),
listen_address=uds.get('address', '0.0.0.0'),
listen_port=int(uds.get('port', '443')),
listen_ipv6=uds.get('ipv6', 'false').lower() == 'true',
workers=int(uds.get('workers', '0')) or multiprocessing.cpu_count(),
ssl_certificate=uds['ssl_certificate'],
ssl_certificate_key=uds['ssl_certificate_key'],
ssl_certificate_key=uds.get('ssl_certificate_key', ''),
ssl_ciphers=uds.get('ssl_ciphers'),
ssl_dhparam=uds.get('ssl_dhparam'),
uds_server=uds_server,

View File

@ -28,34 +28,42 @@
'''
Author: Adolfo Gómez, dkmaster at dkmon dot com
'''
import typing
DEBUG = True
if DEBUG:
CONFIGFILE = 'udstunnel.conf'
LOGFORMAT = '%(levelname)s %(asctime)s %(message)s'
else:
CONFIGFILE = '/etc/udstunnel.conf'
LOGFORMAT = '%(levelname)s %(asctime)s %(message)s'
CONFIGFILE: typing.Final[str] = '/etc/udstunnel.conf' if not DEBUG else 'udstunnel.conf'
LOGFORMAT: typing.Final[str] = '%(levelname)s %(asctime)s %(message)s' if not DEBUG else '%(levelname)s %(asctime)s %(message)s'
# MAX Length of read buffer for proxyed requests
BUFFER_SIZE = 1024 * 16
BUFFER_SIZE: typing.Final[int] = 1024 * 16
# Handshake for conversation start
HANDSHAKE_V1 = b'\x5AMGB\xA5\x01\x00'
HANDSHAKE_V1: typing.Final[bytes] = b'\x5AMGB\xA5\x01\x00'
# Ticket length
TICKET_LENGTH = 48
TICKET_LENGTH: typing.Final[int] = 48
# Max Admin password length (stats basically right now)
PASSWORD_LENGTH = 64
PASSWORD_LENGTH: typing.Final[int] = 64
# Bandwidth calc time lapse
BANDWIDTH_TIME = 10
BANDWIDTH_TIME: typing.Final[int] = 10
# Commands LENGTH (all same length)
COMMAND_LENGTH = 4
COMMAND_LENGTH: typing.Final[int] = 4
VERSION = 'v2.0.0'
VERSION: typing.Final[str] = 'v2.0.0'
# Valid commands
COMMAND_OPEN = b'OPEN'
COMMAND_TEST = b'TEST'
COMMAND_STAT = b'STAT' # full stats
COMMAND_INFO = b'INFO' # Basic stats, currently same as FULL
COMMAND_OPEN: typing.Final[bytes] = b'OPEN'
COMMAND_TEST: typing.Final[bytes] = b'TEST'
COMMAND_STAT: typing.Final[bytes] = b'STAT' # full stats
COMMAND_INFO: typing.Final[bytes] = b'INFO' # Basic stats, currently same as FULL
RESPONSE_ERROR_TICKET: typing.Final[bytes] = b'ERROR_TICKET'
RESPONSE_ERROR_COMMAND: typing.Final[bytes] = b'ERROR_COMMAND'
RESPONSE_ERROR_TIMEOUT: typing.Final[bytes] = b'TIMEOUT'
RESPONSE_FORBIDDEN: typing.Final[bytes] = b'FORBIDDEN'
RESPONSE_OK: typing.Final[bytes] = b'OK'
# Timeout for command
TIMEOUT_COMMAND: typing.Final[int] = 3

View File

@ -156,13 +156,16 @@ class Processes:
ns: 'Namespace',
) -> None:
if cfg.use_uvloop:
import uvloop
try:
import uvloop
if sys.version_info >= (3, 11):
with asyncio.Runner(loop_factory=uvloop.new_event_loop) as runner:
runner.run(proc(conn, cfg, ns))
else:
uvloop.install()
asyncio.run(proc(conn, cfg, ns))
if sys.version_info >= (3, 11):
with asyncio.Runner(loop_factory=uvloop.new_event_loop) as runner:
runner.run(proc(conn, cfg, ns))
else:
uvloop.install()
asyncio.run(proc(conn, cfg, ns))
except ImportError:
logger.warning('uvloop not found, using default asyncio')
else:
asyncio.run(proc(conn, cfg, ns))

View File

@ -45,11 +45,11 @@ logger = logging.getLogger(__name__)
class Proxy:
cfg: 'config.ConfigurationType'
args: 'Namespace'
ns: 'Namespace'
def __init__(self, cfg: 'config.ConfigurationType', args: 'Namespace') -> None:
def __init__(self, cfg: 'config.ConfigurationType', ns: 'Namespace') -> None:
self.cfg = cfg
self.args = args
self.ns = ns
# Method responsible of proxying requests
async def __call__(self, source: socket.socket, context: 'ssl.SSLContext') -> None:
@ -66,14 +66,14 @@ class Proxy:
logger.error('Proxy error from %s: %s', addr, e)
async def proxy(self, source: socket.socket, context: 'ssl.SSLContext') -> None:
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
# Handshake correct in this point, upgrade the connection to TSL and let
# the protocol controller do the rest
# Upgrade connection to SSL, and use asyncio to handle the rest
try:
protocol: tunnel.TunnelProtocol
# (connect accepted loop not present on AbastractEventLoop definition < 3.10)
# (connect accepted loop not present on AbastractEventLoop definition < 3.10), that's why we use ignore
(_, protocol) = await loop.connect_accepted_socket( # type: ignore
lambda: tunnel.TunnelProtocol(self), source, ssl=context
)

View File

@ -66,6 +66,8 @@ class TunnelProtocol(asyncio.Protocol):
stats_manager: stats.Stats
# counter
counter: stats.StatsSingleCounter
# If there is a timeout task running
timeout_task: typing.Optional[asyncio.Task]
def __init__(
self, owner: 'proxy.Proxy', other_side: typing.Optional['TunnelProtocol'] = None
@ -79,7 +81,7 @@ class TunnelProtocol(asyncio.Protocol):
self.runner = self.do_proxy
else:
self.other_side = self
self.stats_manager = stats.Stats(owner.args)
self.stats_manager = stats.Stats(owner.ns)
self.counter = self.stats_manager.as_sent_counter()
self.runner = self.do_command
@ -90,6 +92,10 @@ class TunnelProtocol(asyncio.Protocol):
self.owner = owner
self.source = ('', 0)
self.destination = ('', 0)
self.timeout_task = None
# Set starting timeout task, se we dont get hunged on connections without data
self.set_timeout(consts.TIMEOUT_COMMAND)
def process_open(self) -> None:
# Open Command has the ticket behind it
@ -106,7 +112,7 @@ class TunnelProtocol(asyncio.Protocol):
# clean up the command
self.cmd = b''
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
async def open_other_side() -> None:
try:
@ -115,7 +121,7 @@ class TunnelProtocol(asyncio.Protocol):
)
except Exception as e:
logger.error('ERROR %s', e.args[0] if e.args else e)
self.transport.write(b'ERROR_TICKET')
self.transport.write(consts.RESPONSE_ERROR_TICKET)
self.transport.close() # And force close
return
@ -160,7 +166,7 @@ class TunnelProtocol(asyncio.Protocol):
# Check valid source ip
if self.transport.get_extra_info('peername')[0] not in self.owner.cfg.allow:
# Invalid source
self.transport.write(b'FORBIDDEN')
self.transport.write(consts.RESPONSE_FORBIDDEN)
return
# Check password
@ -171,10 +177,10 @@ class TunnelProtocol(asyncio.Protocol):
if passwd.decode(errors='ignore') != self.owner.cfg.secret:
# Invalid password
self.transport.write(b'FORBIDDEN')
self.transport.write(consts.RESPONSE_FORBIDDEN)
return
data = stats.GlobalStats.get_stats(self.owner.args)
data = stats.GlobalStats.get_stats(self.owner.ns)
for v in data:
logger.debug('SENDING %s', v)
@ -184,8 +190,37 @@ class TunnelProtocol(asyncio.Protocol):
finally:
self.close_connection()
async def timeout(self, wait: int) -> None:
try:
await asyncio.sleep(wait)
logger.error('TIMEOUT FROM %s', self.pretty_source())
self.transport.write(consts.RESPONSE_ERROR_TIMEOUT)
self.close_connection()
except asyncio.CancelledError:
pass
def set_timeout(self, wait: int) -> None:
"""Set a timeout for this connection.
If reached, the connection will be closed.
Args:
wait (int): Timeout in seconds
"""
if self.timeout_task:
self.timeout_task.cancel()
self.timeout_task = asyncio.create_task(self.timeout(wait))
def clean_timeout(self) -> None:
"""Clean the timeout task if any."""
if self.timeout_task:
self.timeout_task.cancel()
self.timeout_task = None
def do_command(self, data: bytes) -> None:
self.clean_timeout()
self.cmd += data
# Ensure we don't get a timeout
if len(self.cmd) >= consts.COMMAND_LENGTH:
logger.info('CONNECT FROM %s', self.pretty_source())
@ -195,7 +230,7 @@ class TunnelProtocol(asyncio.Protocol):
self.process_open()
elif command == consts.COMMAND_TEST:
logger.info('COMMAND: TEST')
self.transport.write(b'OK')
self.transport.write(consts.RESPONSE_OK)
self.close_connection()
return
elif command in (consts.COMMAND_STAT, consts.COMMAND_INFO):
@ -206,9 +241,11 @@ class TunnelProtocol(asyncio.Protocol):
raise Exception('Invalid command')
except Exception:
logger.error('ERROR from %s', self.pretty_source())
self.transport.write(b'ERROR_COMMAND')
self.transport.write(consts.RESPONSE_ERROR_COMMAND)
self.close_connection()
return
else:
self.set_timeout(consts.TIMEOUT_COMMAND)
# if not enough data to process command, wait for more

View File

@ -20,14 +20,20 @@ lognumber = 3
# Listen address. Defaults to 0.0.0.0
address = 0.0.0.0
# Number of workers. Defaults to 0 (means "as much as cores")
workers = 2
# Listening port
port = 7777
# If force ipv6 listen, defaults to false
# Note: if listen address is an ipv6 address, this will be forced to true
ipv6 = false
# Number of workers. Defaults to 0 (means "as much as cores")
workers = 2
# SSL Related parameters.
ssl_certificate = /etc/certs/server.pem
# Key can be included on certificate file, so this is optional
ssl_certificate_key = /etc/certs/key.pem
# ssl_ciphers and ssl_dhparam are optional.
ssl_ciphers = ECDHE-RSA-AES256-GCM-SHA512:DHE-RSA-AES256-GCM-SHA512:ECDHE-RSA-AES256-GCM-SHA384:DHE-RSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-SHA384
@ -40,6 +46,11 @@ ssl_dhparam = /etc/certs/dhparam.pem
# https://www.example.com:14333/uds/rest/tunnel/ticket
uds_server = http://172.27.0.1:8000/uds/rest/tunnel/ticket
uds_token = eBCeFxTBw1IKXCqq-RlncshwWIfrrqxc8y5nehqiqMtRztwD
# Defaults to 10 seconds
# uds_timeout = 10
# If verify ssl certificate on uds server. Defaults to true
# uds_verify_ssl = true
# Secret to get access to admin commands (Currently only stats commands). No default for this.
# Admin commands and only allowed from "allow" ips
@ -50,3 +61,7 @@ secret = MySecret
# Only use IPs, no networks allowed
# defaults to localhost (change if listen address is different from 0.0.0.0)
allow = 127.0.0.1
# If use uvloop as event loop. Defaults to true
# use_uvloop = true

View File

@ -133,7 +133,11 @@ async def tunnel_proc_async(
# Generate SSL context
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
context.load_cert_chain(cfg.ssl_certificate, cfg.ssl_certificate_key)
if cfg.ssl_certificate_key:
context.load_cert_chain(cfg.ssl_certificate, cfg.ssl_certificate_key)
else:
context.load_cert_chain(cfg.ssl_certificate)
if cfg.ssl_ciphers:
context.set_ciphers(cfg.ssl_ciphers)
@ -155,7 +159,7 @@ async def tunnel_proc_async(
# create task for server
tasks.append(asyncio.create_task(run_server()))
while tasks:
while tasks and not do_stop:
to_wait = tasks[:] # Get a copy of the list, and clean the original
# Wait for tasks to finish
done, _ = await asyncio.wait(to_wait, return_when=asyncio.FIRST_COMPLETED)
@ -165,6 +169,13 @@ async def tunnel_proc_async(
if task.exception():
logger.exception('TUNNEL ERROR')
# If any task is still running, cancel it
for task in tasks:
task.cancel()
# Wait for all tasks to finish
await asyncio.gather(*tasks, return_exceptions=True)
logger.info('PROCESS %s stopped', os.getpid())
def process_connection(

View File

@ -57,12 +57,12 @@ lognumber = {lognumber}
# Listen address. Defaults to 0.0.0.0
address = {address}
# Listen port. Defaults to 443
port = {port}
# Number of workers. Defaults to 0 (means "as much as cores")
workers = {workers}
# Listening port
port = 7777
# SSL Related parameters.
ssl_certificate = {ssl_certificate}
ssl_certificate_key = {ssl_certificate_key}
@ -89,6 +89,8 @@ secret = {secret}
# Only use IPs, no networks allowed
# defaults to localhost (change if listen address is different from 0.0.0.0)
allow = {allow}
use_uvloop = {use_uvloop}
'''
def get_config(**overrides) -> typing.Tuple[typing.Mapping[str, typing.Any], config.ConfigurationType]:
@ -100,6 +102,8 @@ def get_config(**overrides) -> typing.Tuple[typing.Mapping[str, typing.Any], con
'logsize': random.randint(0, 100), # Random log size
'lognumber': random.randint(0, 100), # Random log number
'address': f'{random.randint(0, 255)}.{random.randint(0, 255)}.{random.randint(0, 255)}.{random.randint(0, 255)}', # Random address
'port': random.randint(0, 65535), # Random port
'ipv6': random.choice([True, False]), # Random ipv6
'workers': random.randint(1, 100), # Random workers, 0 will return as many as cpu cores
'ssl_certificate': f'/tmp/uds_tunnel_{random.randint(0, 100)}.crt', # Random ssl certificate
'ssl_certificate_key': f'/tmp/uds_tunnel_{random.randint(0, 100)}.key', # Random ssl certificate key
@ -111,6 +115,7 @@ def get_config(**overrides) -> typing.Tuple[typing.Mapping[str, typing.Any], con
'uds_verify_ssl': random.choice([True, False]), # Random verify uds ssl
'secret': f'secret{random.randint(0, 100)}', # Random secret
'allow': f'{random.randint(0, 255)}.0.0.0', # Random allow
'use_uvloop': random.choice([True, False]), # Random use uvloop
}
values.update(overrides)
config_file = io.StringIO(TEST_CONFIG.format(**values))

View File

@ -32,31 +32,48 @@ import typing
import string
import random
import aiohttp
import asyncio
import contextlib
import socket
import ssl
import logging
from unittest import IsolatedAsyncioTestCase, mock
from uds_tunnel import proxy, tunnel, consts
from . import fixtures
from .utils import tools
from .utils import tools, certs
if typing.TYPE_CHECKING:
from uds_tunnel import config
logger = logging.getLogger(__name__)
NOTIFY_TICKET = '0123456789cdef01456789abcdebcdef0123456789abcdef'
UDS_GET_TICKET_RESPONSE = {
'host': '127.0.0.1',
'port': 54876,
UDS_GET_TICKET_RESPONSE = lambda host, port: {
'host': host,
'port': port,
'notify': NOTIFY_TICKET,
}
CALLER_HOST = ('host', 12345)
REMOTE_HOST = ('127.0.0.1', 54876)
def uds_response(_, ticket: bytes, msg: str, queryParams: typing.Optional[typing.Mapping[str, str]] = None) -> typing.Dict[str, typing.Any]:
def uds_response(
_,
ticket: bytes,
msg: str,
queryParams: typing.Optional[typing.Mapping[str, str]] = None,
) -> typing.Dict[str, typing.Any]:
if msg == 'stop':
return {}
return UDS_GET_TICKET_RESPONSE
return UDS_GET_TICKET_RESPONSE(*REMOTE_HOST)
class TestTunnel(IsolatedAsyncioTestCase):
async def test_get_ticket_from_uds(self) -> None:
async def test_get_ticket_from_uds_broker(self) -> None:
_, cfg = fixtures.get_config()
# Test some invalid tickets
# Valid ticket are consts.TICKET_LENGTH bytes long, and must be A-Z, a-z, 0-9
@ -65,7 +82,6 @@ class TestTunnel(IsolatedAsyncioTestCase):
new_callable=tools.AsyncMock,
) as m:
m.side_effect = uds_response
#m.return_value = UDS_GET_TICKET_RESPONSE
for i in range(0, 100):
ticket = ''.join(
random.choices(
@ -86,7 +102,7 @@ class TestTunnel(IsolatedAsyncioTestCase):
cfg, ticket.encode(), CALLER_HOST
)
# Ensure data returned is correct {host, port, notify} from mock
self.assertEqual(ret_value, m.return_value)
self.assertEqual(ret_value, UDS_GET_TICKET_RESPONSE(*REMOTE_HOST))
# Ensure mock was called with correct parameters
print(m.call_args)
# Check calling parameters, first one is the config, second one is the ticket, third one is the caller host
@ -102,7 +118,7 @@ class TestTunnel(IsolatedAsyncioTestCase):
# mock should have been called 100 times
self.assertEqual(m.call_count, 100)
async def test_notify_end_to_uds(self) -> None:
async def test_notify_end_to_uds_broker(self) -> None:
_, cfg = fixtures.get_config()
with mock.patch(
'uds_tunnel.tunnel.TunnelProtocol._readFromUDS',
@ -130,7 +146,7 @@ class TestTunnel(IsolatedAsyncioTestCase):
# mock should have been called 100 times
self.assertEqual(m.call_count, 100)
async def test_read_from_uds(self) -> None:
async def test_read_from_uds_broker(self) -> None:
# Generate a listening http server for testing UDS
# Tesst fine responses:
for use_ssl in (True, False):
@ -142,7 +158,9 @@ class TestTunnel(IsolatedAsyncioTestCase):
f'http{"s" if use_ssl else ""}://127.0.0.1:{server.port}/'
)
_, cfg = fixtures.get_config(
uds_server=fake_uds_server, uds_verify_ssl=False
uds_server=fake_uds_server,
uds_verify_ssl=False,
listen_protocol='http',
)
self.assertEqual(
await TestTunnel.get(fake_uds_server),
@ -155,6 +173,37 @@ class TestTunnel(IsolatedAsyncioTestCase):
)
self.assertEqual(ret, {'result': 'ok'})
async def test_tunnel_invalid_command(self) -> None:
# Test invalid handshake
# data = b''
# future: asyncio.Future = asyncio.Future()
# def callback(ldata: bytes) -> None:
# nonlocal data
# data += ldata
# future.set_result(True)
# Send invalid commands and see what happens
# Commands are 4 bytes length, try with less and more invalid commands
for i in range(0, 100, 10):
# Set timeout to 1 seconds
bad_cmd = bytes(random.randint(0, 255) for _ in range(i)) # Some garbage
consts.TIMEOUT_COMMAND = 0.1 # type: ignore # timeout is a final variable, but we need to change it for testing speed
logger.info(f'Testing invalid command with {bad_cmd!r}')
async with TestTunnel.create_test_tunnel(lambda x: None) as cfg:
# Open connection to tunnel
async with TestTunnel.open_tunnel(cfg) as (reader, writer):
# Send data
writer.write(bad_cmd)
await writer.drain()
# Wait for response
readed = await reader.read(1024)
# Should return consts.ERROR_COMMAND or consts.ERROR_TIMEOUT
if len(bad_cmd) < 4:
self.assertEqual(readed, consts.RESPONSE_ERROR_TIMEOUT)
else:
self.assertEqual(readed, consts.RESPONSE_ERROR_COMMAND)
# Helpers
@staticmethod
async def get(url: str) -> str:
@ -165,3 +214,74 @@ class TestTunnel(IsolatedAsyncioTestCase):
async with session.get(url, **options) as r:
r.raise_for_status()
return await r.text()
@staticmethod
async def create_tunnel_server(
cfg: 'config.ConfigurationType', context: 'ssl.SSLContext'
) -> 'asyncio.Server':
# Create fake proxy
proxy = mock.MagicMock()
proxy.cfg = cfg
proxy.ns = mock.MagicMock()
proxy.ns.current = 0
proxy.ns.total = 0
proxy.ns.sent = 0
proxy.ns.recv = 0
proxy.counter = 0
loop = asyncio.get_running_loop()
# Create an asyncio listen socket on cfg.listen_host:cfg.listen_port
return await loop.create_server(
lambda: tunnel.TunnelProtocol(proxy),
cfg.listen_address,
cfg.listen_port,
ssl=context,
family=socket.AF_INET6 if cfg.listen_ipv6 or ':' in cfg.listen_address else socket.AF_INET,
)
@staticmethod
@contextlib.asynccontextmanager
async def create_test_tunnel(callback: typing.Callable[[bytes], None]) -> typing.AsyncGenerator['config.ConfigurationType', None]:
# Generate a listening server for testing tunnel
# Prepare the end of the tunnel
async with tools.AsyncTCPServer(port=54876, callback=callback) as server:
# Create a tunnel to localhost 13579
# SSl cert for tunnel server
with certs.ssl_context(server.host) as (ssl_ctx, _):
_, cfg = fixtures.get_config(
address=server.host,
port=7777,
)
with mock.patch(
'uds_tunnel.tunnel.TunnelProtocol._readFromUDS',
new_callable=tools.AsyncMock,
) as m:
m.return_value = UDS_GET_TICKET_RESPONSE(server.host, server.port)
tunnel_server = await TestTunnel.create_tunnel_server(cfg, ssl_ctx)
yield cfg
tunnel_server.close()
await tunnel_server.wait_closed()
@staticmethod
@contextlib.asynccontextmanager
async def open_tunnel(
cfg: 'config.ConfigurationType',
) -> typing.AsyncGenerator[typing.Tuple[asyncio.StreamReader, asyncio.StreamWriter], None]:
""" opens an ssl socket to the tunnel server
"""
if cfg.listen_ipv6 or ':' in cfg.listen_address:
family = socket.AF_INET6
else:
family = socket.AF_INET
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
context.check_hostname = False
context.verify_mode = ssl.CERT_NONE
reader, writer = await asyncio.open_connection(
cfg.listen_address, cfg.listen_port, ssl=context, family=family
)
yield reader, writer
writer.close()
await writer.wait_closed()

View File

@ -1,8 +1,12 @@
import secrets
import random
from datetime import datetime, timedelta
import datetime
import tempfile
import ipaddress
import typing
import ssl
import os
import contextlib
from cryptography import x509
from cryptography.x509.oid import NameOID
@ -25,7 +29,7 @@ def selfSignedCert(ip: str) -> typing.Tuple[str, str, str]:
san = x509.SubjectAlternativeName([x509.IPAddress(ipaddress.ip_address(ip))])
basic_contraints = x509.BasicConstraints(ca=True, path_length=0)
now = datetime.utcnow()
now = datetime.datetime.utcnow()
cert = (
x509.CertificateBuilder()
.subject_name(name)
@ -33,7 +37,7 @@ def selfSignedCert(ip: str) -> typing.Tuple[str, str, str]:
.public_key(key.public_key())
.serial_number(random.SystemRandom().randint(0, 1 << 64))
.not_valid_before(now)
.not_valid_after(now + timedelta(days=10 * 365))
.not_valid_after(now + datetime.timedelta(days=10 * 365))
.add_extension(basic_contraints, False)
.add_extension(san, False)
.sign(key, hashes.SHA256(), default_backend())
@ -50,3 +54,47 @@ def selfSignedCert(ip: str) -> typing.Tuple[str, str, str]:
cert.public_bytes(encoding=serialization.Encoding.PEM).decode(),
password,
)
def sslContext(ip: str) -> typing.Tuple[ssl.SSLContext, str, str]:
"""Returns an ssl context an the certificate & password for an ip
Args:
ip (str): Ip for subject name
Returns:
typing.Tuple[ssl.SSLContext, str, str]: ssl context, certificate file and password
"""
# First, create server cert and key on temp dir
tmpdir = tempfile.gettempdir()
cert, key, password = selfSignedCert('127.0.0.1')
cert_file = f'{tmpdir}/tmp_cert.pem'
with open(cert_file, 'w') as f:
f.write(key)
f.write(cert)
# Create SSL context
ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
ssl_ctx.options |= ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1
ssl_ctx.load_cert_chain(certfile=f'{tmpdir}/tmp_cert.pem', password=password)
ssl_ctx.check_hostname = False
ssl_ctx.set_ciphers('ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384')
return ssl_ctx, cert_file, password
@contextlib.contextmanager
def ssl_context(ip: str) -> typing.Generator[typing.Tuple[ssl.SSLContext, str], None, None]:
"""Returns an ssl context for an ip
Args:
ip (str): Ip for subject name
Returns:
ssl.SSLContext: ssl context
"""
# First, create server cert and key on temp dir
ssl_ctx, cert_file, password = sslContext(ip)
yield ssl_ctx, cert_file
# Remove cert file
os.remove(cert_file)

View File

@ -37,6 +37,7 @@ from unittest import mock
from . import certs
class AsyncMock(mock.MagicMock):
async def __call__(self, *args, **kwargs):
return super().__call__(*args, **kwargs)
@ -44,42 +45,37 @@ class AsyncMock(mock.MagicMock):
# simple async http server, will return 200 OK with the request path as body
class AsyncHttpServer:
host: str
port: int
_server: typing.Optional[asyncio.AbstractServer]
_response: typing.Optional[bytes]
_ssl_ctx: typing.Optional[ssl.SSLContext]
_ssl_cert_file: typing.Optional[str]
def __init__(
self, port: int, *, response: typing.Optional[bytes] = None, use_ssl: bool = False,
self,
port: int,
*,
response: typing.Optional[bytes] = None,
use_ssl: bool = False,
host: str = '127.0.0.1' # ip
):
) -> None:
self.host = host
self.port = port
self._server = None
self._response = response
if use_ssl:
# First, create server cert and key on temp dir
tmpdir = tempfile.gettempdir()
cert, key, password = certs.selfSignedCert('127.0.0.1')
with open(f'{tmpdir}/tmp_cert.pem', 'w') as f:
f.write(key)
f.write(cert)
# Create SSL context
self._ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
self._ssl_ctx.options |= ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1
self._ssl_ctx.load_cert_chain(certfile=f'{tmpdir}/tmp_cert.pem', password=password)
self._ssl_ctx.check_hostname = False
self._ssl_ctx.set_ciphers(
'ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384'
)
self._ssl_ctx, self._ssl_cert_file, pwd = certs.sslContext(host)
else:
self._ssl_ctx = None
self._ssl_cert_file = None
# on end, remove certs
def __del__(self):
tmpdir = tempfile.gettempdir()
# os.remove(f'{tmpdir}/tmp_cert.pem')
def __del__(self) -> None:
if self._ssl_cert_file:
os.unlink(self._ssl_cert_file)
async def _handle(self, reader, writer):
async def _handle(self, reader, writer) -> None:
data = await reader.read(2048)
path: bytes = data.split()[1]
if self._response is not None:
@ -90,20 +86,69 @@ class AsyncHttpServer:
)
await writer.drain()
async def __aenter__(self):
async def __aenter__(self) -> 'AsyncHttpServer':
if self._ssl_ctx is not None:
self._server = await asyncio.start_server(
self._handle, '127.0.0.1', self.port, ssl=self._ssl_ctx
self._handle, self.host, self.port, ssl=self._ssl_ctx
)
else:
self._server = await asyncio.start_server(
self._handle, '127.0.0.1', self.port
self._handle, self.host, self.port
)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
if self._server is not None:
self._server.close()
await self._server.wait_closed()
self._server = None
class AsyncTCPServer:
host: str
port: int
_server: typing.Optional[asyncio.AbstractServer]
_response: typing.Optional[bytes]
_callback: typing.Optional[typing.Callable[[bytes], None]]
def __init__(
self,
port: int,
*,
response: typing.Optional[bytes] = None,
host: str = '127.0.0.1', # ip
callback: typing.Optional[typing.Callable[[bytes], None]] = None
) -> None:
self.host = host
self.port = port
self._server = None
self._response = response
self._callback = callback
self.data = b''
async def _handle(self, reader, writer) -> None:
data = await reader.read(2048)
if self._callback:
self._callback(data)
if self._response is not None:
data = self._response
else:
data = b'sample data'
writer.write(data)
await writer.drain()
async def __aenter__(self) -> 'AsyncTCPServer':
self._server = await asyncio.start_server(
self._handle,
self.host,
self.port
)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
if self._server is not None:
self._server.close()
await self._server.wait_closed()
self._server = None