mirror of
https://github.com/dkmstr/openuds.git
synced 2025-03-11 00:58:39 +03:00
added basic structure for testing tunnel server and some tests
This commit is contained in:
parent
406f32c2fa
commit
081dfc9995
@ -38,16 +38,29 @@ import threading
|
||||
import select
|
||||
import typing
|
||||
import logging
|
||||
import enum
|
||||
|
||||
from . import tools
|
||||
|
||||
HANDSHAKE_V1 = b'\x5AMGB\xA5\x01\x00'
|
||||
BUFFER_SIZE = 1024 * 16 # Max buffer length
|
||||
DEBUG = True
|
||||
LISTEN_ADDRESS = '0.0.0.0' if DEBUG else '127.0.0.1'
|
||||
|
||||
BUFFER_SIZE: typing.Final[int] = 1024 * 16 # Max buffer length
|
||||
LISTEN_ADDRESS: typing.Final[str] = '0.0.0.0' if DEBUG else '127.0.0.1'
|
||||
LISTEN_ADDRESS_V6: typing.Final[str] = '::' if DEBUG else '::1'
|
||||
|
||||
# ForwarServer states
|
||||
TUNNEL_LISTENING, TUNNEL_OPENING, TUNNEL_PROCESSING, TUNNEL_ERROR = 0, 1, 2, 3
|
||||
class ForwardState(enum.IntEnum):
|
||||
TUNNEL_LISTENING = 0
|
||||
TUNNEL_OPENING = 1
|
||||
TUNNEL_PROCESSING = 2
|
||||
TUNNEL_ERROR = 3
|
||||
|
||||
# Some constants strings for protocol
|
||||
HANDSHAKE_V1: typing.Final[bytes] = b'\x5AMGB\xA5\x01\x00'
|
||||
CMD_TEST: typing.Final[bytes] = b'TEST'
|
||||
CMD_OPEN: typing.Final[bytes] = b'OPEN'
|
||||
|
||||
RESPONSE_OK: typing.Final[bytes] = b'OK'
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -57,6 +70,7 @@ class ForwardServer(socketserver.ThreadingTCPServer):
|
||||
allow_reuse_address = True
|
||||
|
||||
remote: typing.Tuple[str, int]
|
||||
remote_ipv6: bool
|
||||
ticket: str
|
||||
stop_flag: threading.Event
|
||||
can_stop: bool
|
||||
@ -64,7 +78,9 @@ class ForwardServer(socketserver.ThreadingTCPServer):
|
||||
timer: typing.Optional[threading.Timer]
|
||||
check_certificate: bool
|
||||
current_connections: int
|
||||
status: int
|
||||
status: ForwardState
|
||||
|
||||
address_family = socket.AF_INET
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -73,14 +89,21 @@ class ForwardServer(socketserver.ThreadingTCPServer):
|
||||
timeout: int = 0,
|
||||
local_port: int = 0,
|
||||
check_certificate: bool = True,
|
||||
ipv6_listen: bool = False,
|
||||
ipv6_remote: bool = False,
|
||||
) -> None:
|
||||
|
||||
local_port = local_port or random.randrange(33000, 53000)
|
||||
|
||||
if ipv6_listen:
|
||||
self.address_family = socket.AF_INET6
|
||||
|
||||
super().__init__(
|
||||
server_address=(LISTEN_ADDRESS, local_port), RequestHandlerClass=Handler
|
||||
server_address=(LISTEN_ADDRESS if ipv6_listen else LISTEN_ADDRESS_V6, local_port),
|
||||
RequestHandlerClass=Handler,
|
||||
)
|
||||
self.remote = remote
|
||||
self.remote_ipv6 = ipv6_remote or ':' in remote[0] # if ':' in remote address, it's ipv6 (port is [1])
|
||||
self.ticket = ticket
|
||||
# Negative values for timeout, means "accept always connections"
|
||||
# "but if no connection is stablished on timeout (positive)"
|
||||
@ -90,7 +113,7 @@ class ForwardServer(socketserver.ThreadingTCPServer):
|
||||
self.stop_flag = threading.Event() # False initial
|
||||
self.current_connections = 0
|
||||
|
||||
self.status = TUNNEL_LISTENING
|
||||
self.status = ForwardState.TUNNEL_LISTENING
|
||||
self.can_stop = False
|
||||
|
||||
timeout = abs(timeout) or 60
|
||||
@ -109,7 +132,7 @@ class ForwardServer(socketserver.ThreadingTCPServer):
|
||||
self.shutdown()
|
||||
|
||||
def connect(self) -> ssl.SSLSocket:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as rsocket:
|
||||
with socket.socket(socket.AF_INET6 if self.remote_ipv6 else socket.AF_INET, socket.SOCK_STREAM) as rsocket:
|
||||
logger.info('CONNECT to %s', self.remote)
|
||||
|
||||
rsocket.connect(self.remote)
|
||||
@ -134,16 +157,16 @@ class ForwardServer(socketserver.ThreadingTCPServer):
|
||||
return context.wrap_socket(rsocket, server_hostname=self.remote[0])
|
||||
|
||||
def check(self) -> bool:
|
||||
if self.status == TUNNEL_ERROR:
|
||||
if self.status == ForwardState.TUNNEL_ERROR:
|
||||
return False
|
||||
|
||||
logger.debug('Checking tunnel availability')
|
||||
|
||||
try:
|
||||
with self.connect() as ssl_socket:
|
||||
ssl_socket.sendall(b'TEST')
|
||||
ssl_socket.sendall(CMD_TEST)
|
||||
resp = ssl_socket.recv(2)
|
||||
if resp != b'OK':
|
||||
if resp != RESPONSE_OK:
|
||||
raise Exception({'Invalid tunnelresponse: {resp}'})
|
||||
logger.debug('Tunnel is available!')
|
||||
return True
|
||||
@ -173,11 +196,11 @@ class Handler(socketserver.BaseRequestHandler):
|
||||
|
||||
# server: ForwardServer
|
||||
def handle(self) -> None:
|
||||
self.server.status = TUNNEL_OPENING
|
||||
self.server.status = ForwardState.TUNNEL_OPENING
|
||||
|
||||
# If server processing is over time
|
||||
# If server new connections processing are over time...
|
||||
if self.server.stoppable:
|
||||
self.server.status = TUNNEL_ERROR
|
||||
self.server.status = ForwardState.TUNNEL_ERROR
|
||||
logger.info('Rejected timedout connection')
|
||||
self.request.close() # End connection without processing it
|
||||
return
|
||||
@ -189,10 +212,10 @@ class Handler(socketserver.BaseRequestHandler):
|
||||
logger.debug('Ticket %s', self.server.ticket)
|
||||
with self.server.connect() as ssl_socket:
|
||||
# Send handhshake + command + ticket
|
||||
ssl_socket.sendall(b'OPEN' + self.server.ticket.encode())
|
||||
ssl_socket.sendall(CMD_OPEN + self.server.ticket.encode())
|
||||
# Check response is OK
|
||||
data = ssl_socket.recv(2)
|
||||
if data != b'OK':
|
||||
if data != RESPONSE_OK:
|
||||
data += ssl_socket.recv(128)
|
||||
raise Exception(
|
||||
f'Error received: {data.decode(errors="ignore")}'
|
||||
@ -202,7 +225,7 @@ class Handler(socketserver.BaseRequestHandler):
|
||||
self.process(remote=ssl_socket)
|
||||
except Exception as e:
|
||||
logger.error(f'Error connecting to {self.server.remote!s}: {e!s}')
|
||||
self.server.status = TUNNEL_ERROR
|
||||
self.server.status = ForwardState.TUNNEL_ERROR
|
||||
self.server.stop()
|
||||
finally:
|
||||
self.server.current_connections -= 1
|
||||
@ -212,7 +235,7 @@ class Handler(socketserver.BaseRequestHandler):
|
||||
|
||||
# Processes data forwarding
|
||||
def process(self, remote: ssl.SSLSocket):
|
||||
self.server.status = TUNNEL_PROCESSING
|
||||
self.server.status = ForwardState.TUNNEL_PROCESSING
|
||||
logger.debug('Processing tunnel with ticket %s', self.server.ticket)
|
||||
# Process data until stop requested or connection closed
|
||||
try:
|
||||
|
@ -50,6 +50,7 @@ class ConfigurationType(typing.NamedTuple):
|
||||
|
||||
listen_address: str
|
||||
listen_port: int
|
||||
listen_ipv6: bool
|
||||
|
||||
workers: int
|
||||
|
||||
@ -117,9 +118,10 @@ def read(
|
||||
log_number=int(uds.get('lognumber', '3')),
|
||||
listen_address=uds.get('address', '0.0.0.0'),
|
||||
listen_port=int(uds.get('port', '443')),
|
||||
listen_ipv6=uds.get('ipv6', 'false').lower() == 'true',
|
||||
workers=int(uds.get('workers', '0')) or multiprocessing.cpu_count(),
|
||||
ssl_certificate=uds['ssl_certificate'],
|
||||
ssl_certificate_key=uds['ssl_certificate_key'],
|
||||
ssl_certificate_key=uds.get('ssl_certificate_key', ''),
|
||||
ssl_ciphers=uds.get('ssl_ciphers'),
|
||||
ssl_dhparam=uds.get('ssl_dhparam'),
|
||||
uds_server=uds_server,
|
||||
|
@ -28,34 +28,42 @@
|
||||
'''
|
||||
Author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
'''
|
||||
import typing
|
||||
|
||||
DEBUG = True
|
||||
|
||||
if DEBUG:
|
||||
CONFIGFILE = 'udstunnel.conf'
|
||||
LOGFORMAT = '%(levelname)s %(asctime)s %(message)s'
|
||||
else:
|
||||
CONFIGFILE = '/etc/udstunnel.conf'
|
||||
LOGFORMAT = '%(levelname)s %(asctime)s %(message)s'
|
||||
CONFIGFILE: typing.Final[str] = '/etc/udstunnel.conf' if not DEBUG else 'udstunnel.conf'
|
||||
LOGFORMAT: typing.Final[str] = '%(levelname)s %(asctime)s %(message)s' if not DEBUG else '%(levelname)s %(asctime)s %(message)s'
|
||||
|
||||
# MAX Length of read buffer for proxyed requests
|
||||
BUFFER_SIZE = 1024 * 16
|
||||
BUFFER_SIZE: typing.Final[int] = 1024 * 16
|
||||
# Handshake for conversation start
|
||||
HANDSHAKE_V1 = b'\x5AMGB\xA5\x01\x00'
|
||||
HANDSHAKE_V1: typing.Final[bytes] = b'\x5AMGB\xA5\x01\x00'
|
||||
# Ticket length
|
||||
TICKET_LENGTH = 48
|
||||
TICKET_LENGTH: typing.Final[int] = 48
|
||||
# Max Admin password length (stats basically right now)
|
||||
PASSWORD_LENGTH = 64
|
||||
PASSWORD_LENGTH: typing.Final[int] = 64
|
||||
# Bandwidth calc time lapse
|
||||
BANDWIDTH_TIME = 10
|
||||
BANDWIDTH_TIME: typing.Final[int] = 10
|
||||
|
||||
# Commands LENGTH (all same length)
|
||||
COMMAND_LENGTH = 4
|
||||
COMMAND_LENGTH: typing.Final[int] = 4
|
||||
|
||||
VERSION = 'v2.0.0'
|
||||
VERSION: typing.Final[str] = 'v2.0.0'
|
||||
|
||||
# Valid commands
|
||||
COMMAND_OPEN = b'OPEN'
|
||||
COMMAND_TEST = b'TEST'
|
||||
COMMAND_STAT = b'STAT' # full stats
|
||||
COMMAND_INFO = b'INFO' # Basic stats, currently same as FULL
|
||||
COMMAND_OPEN: typing.Final[bytes] = b'OPEN'
|
||||
COMMAND_TEST: typing.Final[bytes] = b'TEST'
|
||||
COMMAND_STAT: typing.Final[bytes] = b'STAT' # full stats
|
||||
COMMAND_INFO: typing.Final[bytes] = b'INFO' # Basic stats, currently same as FULL
|
||||
|
||||
RESPONSE_ERROR_TICKET: typing.Final[bytes] = b'ERROR_TICKET'
|
||||
RESPONSE_ERROR_COMMAND: typing.Final[bytes] = b'ERROR_COMMAND'
|
||||
RESPONSE_ERROR_TIMEOUT: typing.Final[bytes] = b'TIMEOUT'
|
||||
RESPONSE_FORBIDDEN: typing.Final[bytes] = b'FORBIDDEN'
|
||||
|
||||
RESPONSE_OK: typing.Final[bytes] = b'OK'
|
||||
|
||||
|
||||
# Timeout for command
|
||||
TIMEOUT_COMMAND: typing.Final[int] = 3
|
@ -156,13 +156,16 @@ class Processes:
|
||||
ns: 'Namespace',
|
||||
) -> None:
|
||||
if cfg.use_uvloop:
|
||||
import uvloop
|
||||
try:
|
||||
import uvloop
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
with asyncio.Runner(loop_factory=uvloop.new_event_loop) as runner:
|
||||
runner.run(proc(conn, cfg, ns))
|
||||
else:
|
||||
uvloop.install()
|
||||
asyncio.run(proc(conn, cfg, ns))
|
||||
if sys.version_info >= (3, 11):
|
||||
with asyncio.Runner(loop_factory=uvloop.new_event_loop) as runner:
|
||||
runner.run(proc(conn, cfg, ns))
|
||||
else:
|
||||
uvloop.install()
|
||||
asyncio.run(proc(conn, cfg, ns))
|
||||
except ImportError:
|
||||
logger.warning('uvloop not found, using default asyncio')
|
||||
else:
|
||||
asyncio.run(proc(conn, cfg, ns))
|
||||
|
@ -45,11 +45,11 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class Proxy:
|
||||
cfg: 'config.ConfigurationType'
|
||||
args: 'Namespace'
|
||||
ns: 'Namespace'
|
||||
|
||||
def __init__(self, cfg: 'config.ConfigurationType', args: 'Namespace') -> None:
|
||||
def __init__(self, cfg: 'config.ConfigurationType', ns: 'Namespace') -> None:
|
||||
self.cfg = cfg
|
||||
self.args = args
|
||||
self.ns = ns
|
||||
|
||||
# Method responsible of proxying requests
|
||||
async def __call__(self, source: socket.socket, context: 'ssl.SSLContext') -> None:
|
||||
@ -66,14 +66,14 @@ class Proxy:
|
||||
logger.error('Proxy error from %s: %s', addr, e)
|
||||
|
||||
async def proxy(self, source: socket.socket, context: 'ssl.SSLContext') -> None:
|
||||
loop = asyncio.get_event_loop()
|
||||
loop = asyncio.get_running_loop()
|
||||
# Handshake correct in this point, upgrade the connection to TSL and let
|
||||
# the protocol controller do the rest
|
||||
|
||||
# Upgrade connection to SSL, and use asyncio to handle the rest
|
||||
try:
|
||||
protocol: tunnel.TunnelProtocol
|
||||
# (connect accepted loop not present on AbastractEventLoop definition < 3.10)
|
||||
# (connect accepted loop not present on AbastractEventLoop definition < 3.10), that's why we use ignore
|
||||
(_, protocol) = await loop.connect_accepted_socket( # type: ignore
|
||||
lambda: tunnel.TunnelProtocol(self), source, ssl=context
|
||||
)
|
||||
|
@ -66,6 +66,8 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
stats_manager: stats.Stats
|
||||
# counter
|
||||
counter: stats.StatsSingleCounter
|
||||
# If there is a timeout task running
|
||||
timeout_task: typing.Optional[asyncio.Task]
|
||||
|
||||
def __init__(
|
||||
self, owner: 'proxy.Proxy', other_side: typing.Optional['TunnelProtocol'] = None
|
||||
@ -79,7 +81,7 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
self.runner = self.do_proxy
|
||||
else:
|
||||
self.other_side = self
|
||||
self.stats_manager = stats.Stats(owner.args)
|
||||
self.stats_manager = stats.Stats(owner.ns)
|
||||
self.counter = self.stats_manager.as_sent_counter()
|
||||
self.runner = self.do_command
|
||||
|
||||
@ -90,6 +92,10 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
self.owner = owner
|
||||
self.source = ('', 0)
|
||||
self.destination = ('', 0)
|
||||
self.timeout_task = None
|
||||
|
||||
# Set starting timeout task, se we dont get hunged on connections without data
|
||||
self.set_timeout(consts.TIMEOUT_COMMAND)
|
||||
|
||||
def process_open(self) -> None:
|
||||
# Open Command has the ticket behind it
|
||||
@ -106,7 +112,7 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
# clean up the command
|
||||
self.cmd = b''
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
async def open_other_side() -> None:
|
||||
try:
|
||||
@ -115,7 +121,7 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error('ERROR %s', e.args[0] if e.args else e)
|
||||
self.transport.write(b'ERROR_TICKET')
|
||||
self.transport.write(consts.RESPONSE_ERROR_TICKET)
|
||||
self.transport.close() # And force close
|
||||
return
|
||||
|
||||
@ -160,7 +166,7 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
# Check valid source ip
|
||||
if self.transport.get_extra_info('peername')[0] not in self.owner.cfg.allow:
|
||||
# Invalid source
|
||||
self.transport.write(b'FORBIDDEN')
|
||||
self.transport.write(consts.RESPONSE_FORBIDDEN)
|
||||
return
|
||||
|
||||
# Check password
|
||||
@ -171,10 +177,10 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
|
||||
if passwd.decode(errors='ignore') != self.owner.cfg.secret:
|
||||
# Invalid password
|
||||
self.transport.write(b'FORBIDDEN')
|
||||
self.transport.write(consts.RESPONSE_FORBIDDEN)
|
||||
return
|
||||
|
||||
data = stats.GlobalStats.get_stats(self.owner.args)
|
||||
data = stats.GlobalStats.get_stats(self.owner.ns)
|
||||
|
||||
for v in data:
|
||||
logger.debug('SENDING %s', v)
|
||||
@ -184,8 +190,37 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
finally:
|
||||
self.close_connection()
|
||||
|
||||
async def timeout(self, wait: int) -> None:
|
||||
try:
|
||||
await asyncio.sleep(wait)
|
||||
logger.error('TIMEOUT FROM %s', self.pretty_source())
|
||||
self.transport.write(consts.RESPONSE_ERROR_TIMEOUT)
|
||||
self.close_connection()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
def set_timeout(self, wait: int) -> None:
|
||||
"""Set a timeout for this connection.
|
||||
If reached, the connection will be closed.
|
||||
|
||||
Args:
|
||||
wait (int): Timeout in seconds
|
||||
|
||||
"""
|
||||
if self.timeout_task:
|
||||
self.timeout_task.cancel()
|
||||
self.timeout_task = asyncio.create_task(self.timeout(wait))
|
||||
|
||||
def clean_timeout(self) -> None:
|
||||
"""Clean the timeout task if any."""
|
||||
if self.timeout_task:
|
||||
self.timeout_task.cancel()
|
||||
self.timeout_task = None
|
||||
|
||||
def do_command(self, data: bytes) -> None:
|
||||
self.clean_timeout()
|
||||
self.cmd += data
|
||||
# Ensure we don't get a timeout
|
||||
if len(self.cmd) >= consts.COMMAND_LENGTH:
|
||||
logger.info('CONNECT FROM %s', self.pretty_source())
|
||||
|
||||
@ -195,7 +230,7 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
self.process_open()
|
||||
elif command == consts.COMMAND_TEST:
|
||||
logger.info('COMMAND: TEST')
|
||||
self.transport.write(b'OK')
|
||||
self.transport.write(consts.RESPONSE_OK)
|
||||
self.close_connection()
|
||||
return
|
||||
elif command in (consts.COMMAND_STAT, consts.COMMAND_INFO):
|
||||
@ -206,9 +241,11 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
raise Exception('Invalid command')
|
||||
except Exception:
|
||||
logger.error('ERROR from %s', self.pretty_source())
|
||||
self.transport.write(b'ERROR_COMMAND')
|
||||
self.transport.write(consts.RESPONSE_ERROR_COMMAND)
|
||||
self.close_connection()
|
||||
return
|
||||
else:
|
||||
self.set_timeout(consts.TIMEOUT_COMMAND)
|
||||
|
||||
# if not enough data to process command, wait for more
|
||||
|
||||
|
@ -20,14 +20,20 @@ lognumber = 3
|
||||
# Listen address. Defaults to 0.0.0.0
|
||||
address = 0.0.0.0
|
||||
|
||||
# Number of workers. Defaults to 0 (means "as much as cores")
|
||||
workers = 2
|
||||
|
||||
# Listening port
|
||||
port = 7777
|
||||
|
||||
# If force ipv6 listen, defaults to false
|
||||
# Note: if listen address is an ipv6 address, this will be forced to true
|
||||
ipv6 = false
|
||||
|
||||
# Number of workers. Defaults to 0 (means "as much as cores")
|
||||
workers = 2
|
||||
|
||||
|
||||
# SSL Related parameters.
|
||||
ssl_certificate = /etc/certs/server.pem
|
||||
# Key can be included on certificate file, so this is optional
|
||||
ssl_certificate_key = /etc/certs/key.pem
|
||||
# ssl_ciphers and ssl_dhparam are optional.
|
||||
ssl_ciphers = ECDHE-RSA-AES256-GCM-SHA512:DHE-RSA-AES256-GCM-SHA512:ECDHE-RSA-AES256-GCM-SHA384:DHE-RSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-SHA384
|
||||
@ -40,6 +46,11 @@ ssl_dhparam = /etc/certs/dhparam.pem
|
||||
# https://www.example.com:14333/uds/rest/tunnel/ticket
|
||||
uds_server = http://172.27.0.1:8000/uds/rest/tunnel/ticket
|
||||
uds_token = eBCeFxTBw1IKXCqq-RlncshwWIfrrqxc8y5nehqiqMtRztwD
|
||||
# Defaults to 10 seconds
|
||||
# uds_timeout = 10
|
||||
|
||||
# If verify ssl certificate on uds server. Defaults to true
|
||||
# uds_verify_ssl = true
|
||||
|
||||
# Secret to get access to admin commands (Currently only stats commands). No default for this.
|
||||
# Admin commands and only allowed from "allow" ips
|
||||
@ -50,3 +61,7 @@ secret = MySecret
|
||||
# Only use IPs, no networks allowed
|
||||
# defaults to localhost (change if listen address is different from 0.0.0.0)
|
||||
allow = 127.0.0.1
|
||||
|
||||
|
||||
# If use uvloop as event loop. Defaults to true
|
||||
# use_uvloop = true
|
@ -133,7 +133,11 @@ async def tunnel_proc_async(
|
||||
|
||||
# Generate SSL context
|
||||
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||
context.load_cert_chain(cfg.ssl_certificate, cfg.ssl_certificate_key)
|
||||
|
||||
if cfg.ssl_certificate_key:
|
||||
context.load_cert_chain(cfg.ssl_certificate, cfg.ssl_certificate_key)
|
||||
else:
|
||||
context.load_cert_chain(cfg.ssl_certificate)
|
||||
|
||||
if cfg.ssl_ciphers:
|
||||
context.set_ciphers(cfg.ssl_ciphers)
|
||||
@ -155,7 +159,7 @@ async def tunnel_proc_async(
|
||||
# create task for server
|
||||
tasks.append(asyncio.create_task(run_server()))
|
||||
|
||||
while tasks:
|
||||
while tasks and not do_stop:
|
||||
to_wait = tasks[:] # Get a copy of the list, and clean the original
|
||||
# Wait for tasks to finish
|
||||
done, _ = await asyncio.wait(to_wait, return_when=asyncio.FIRST_COMPLETED)
|
||||
@ -165,6 +169,13 @@ async def tunnel_proc_async(
|
||||
if task.exception():
|
||||
logger.exception('TUNNEL ERROR')
|
||||
|
||||
# If any task is still running, cancel it
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
|
||||
# Wait for all tasks to finish
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
logger.info('PROCESS %s stopped', os.getpid())
|
||||
|
||||
def process_connection(
|
||||
|
@ -57,12 +57,12 @@ lognumber = {lognumber}
|
||||
# Listen address. Defaults to 0.0.0.0
|
||||
address = {address}
|
||||
|
||||
# Listen port. Defaults to 443
|
||||
port = {port}
|
||||
|
||||
# Number of workers. Defaults to 0 (means "as much as cores")
|
||||
workers = {workers}
|
||||
|
||||
# Listening port
|
||||
port = 7777
|
||||
|
||||
# SSL Related parameters.
|
||||
ssl_certificate = {ssl_certificate}
|
||||
ssl_certificate_key = {ssl_certificate_key}
|
||||
@ -89,6 +89,8 @@ secret = {secret}
|
||||
# Only use IPs, no networks allowed
|
||||
# defaults to localhost (change if listen address is different from 0.0.0.0)
|
||||
allow = {allow}
|
||||
|
||||
use_uvloop = {use_uvloop}
|
||||
'''
|
||||
|
||||
def get_config(**overrides) -> typing.Tuple[typing.Mapping[str, typing.Any], config.ConfigurationType]:
|
||||
@ -100,6 +102,8 @@ def get_config(**overrides) -> typing.Tuple[typing.Mapping[str, typing.Any], con
|
||||
'logsize': random.randint(0, 100), # Random log size
|
||||
'lognumber': random.randint(0, 100), # Random log number
|
||||
'address': f'{random.randint(0, 255)}.{random.randint(0, 255)}.{random.randint(0, 255)}.{random.randint(0, 255)}', # Random address
|
||||
'port': random.randint(0, 65535), # Random port
|
||||
'ipv6': random.choice([True, False]), # Random ipv6
|
||||
'workers': random.randint(1, 100), # Random workers, 0 will return as many as cpu cores
|
||||
'ssl_certificate': f'/tmp/uds_tunnel_{random.randint(0, 100)}.crt', # Random ssl certificate
|
||||
'ssl_certificate_key': f'/tmp/uds_tunnel_{random.randint(0, 100)}.key', # Random ssl certificate key
|
||||
@ -111,6 +115,7 @@ def get_config(**overrides) -> typing.Tuple[typing.Mapping[str, typing.Any], con
|
||||
'uds_verify_ssl': random.choice([True, False]), # Random verify uds ssl
|
||||
'secret': f'secret{random.randint(0, 100)}', # Random secret
|
||||
'allow': f'{random.randint(0, 255)}.0.0.0', # Random allow
|
||||
'use_uvloop': random.choice([True, False]), # Random use uvloop
|
||||
}
|
||||
values.update(overrides)
|
||||
config_file = io.StringIO(TEST_CONFIG.format(**values))
|
||||
|
@ -32,31 +32,48 @@ import typing
|
||||
import string
|
||||
import random
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import contextlib
|
||||
import socket
|
||||
import ssl
|
||||
import logging
|
||||
|
||||
from unittest import IsolatedAsyncioTestCase, mock
|
||||
|
||||
from uds_tunnel import proxy, tunnel, consts
|
||||
|
||||
from . import fixtures
|
||||
from .utils import tools
|
||||
from .utils import tools, certs
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from uds_tunnel import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
NOTIFY_TICKET = '0123456789cdef01456789abcdebcdef0123456789abcdef'
|
||||
UDS_GET_TICKET_RESPONSE = {
|
||||
'host': '127.0.0.1',
|
||||
'port': 54876,
|
||||
UDS_GET_TICKET_RESPONSE = lambda host, port: {
|
||||
'host': host,
|
||||
'port': port,
|
||||
'notify': NOTIFY_TICKET,
|
||||
}
|
||||
CALLER_HOST = ('host', 12345)
|
||||
REMOTE_HOST = ('127.0.0.1', 54876)
|
||||
|
||||
def uds_response(_, ticket: bytes, msg: str, queryParams: typing.Optional[typing.Mapping[str, str]] = None) -> typing.Dict[str, typing.Any]:
|
||||
|
||||
def uds_response(
|
||||
_,
|
||||
ticket: bytes,
|
||||
msg: str,
|
||||
queryParams: typing.Optional[typing.Mapping[str, str]] = None,
|
||||
) -> typing.Dict[str, typing.Any]:
|
||||
if msg == 'stop':
|
||||
return {}
|
||||
|
||||
return UDS_GET_TICKET_RESPONSE
|
||||
return UDS_GET_TICKET_RESPONSE(*REMOTE_HOST)
|
||||
|
||||
|
||||
class TestTunnel(IsolatedAsyncioTestCase):
|
||||
async def test_get_ticket_from_uds(self) -> None:
|
||||
async def test_get_ticket_from_uds_broker(self) -> None:
|
||||
_, cfg = fixtures.get_config()
|
||||
# Test some invalid tickets
|
||||
# Valid ticket are consts.TICKET_LENGTH bytes long, and must be A-Z, a-z, 0-9
|
||||
@ -65,7 +82,6 @@ class TestTunnel(IsolatedAsyncioTestCase):
|
||||
new_callable=tools.AsyncMock,
|
||||
) as m:
|
||||
m.side_effect = uds_response
|
||||
#m.return_value = UDS_GET_TICKET_RESPONSE
|
||||
for i in range(0, 100):
|
||||
ticket = ''.join(
|
||||
random.choices(
|
||||
@ -86,7 +102,7 @@ class TestTunnel(IsolatedAsyncioTestCase):
|
||||
cfg, ticket.encode(), CALLER_HOST
|
||||
)
|
||||
# Ensure data returned is correct {host, port, notify} from mock
|
||||
self.assertEqual(ret_value, m.return_value)
|
||||
self.assertEqual(ret_value, UDS_GET_TICKET_RESPONSE(*REMOTE_HOST))
|
||||
# Ensure mock was called with correct parameters
|
||||
print(m.call_args)
|
||||
# Check calling parameters, first one is the config, second one is the ticket, third one is the caller host
|
||||
@ -102,7 +118,7 @@ class TestTunnel(IsolatedAsyncioTestCase):
|
||||
# mock should have been called 100 times
|
||||
self.assertEqual(m.call_count, 100)
|
||||
|
||||
async def test_notify_end_to_uds(self) -> None:
|
||||
async def test_notify_end_to_uds_broker(self) -> None:
|
||||
_, cfg = fixtures.get_config()
|
||||
with mock.patch(
|
||||
'uds_tunnel.tunnel.TunnelProtocol._readFromUDS',
|
||||
@ -130,7 +146,7 @@ class TestTunnel(IsolatedAsyncioTestCase):
|
||||
# mock should have been called 100 times
|
||||
self.assertEqual(m.call_count, 100)
|
||||
|
||||
async def test_read_from_uds(self) -> None:
|
||||
async def test_read_from_uds_broker(self) -> None:
|
||||
# Generate a listening http server for testing UDS
|
||||
# Tesst fine responses:
|
||||
for use_ssl in (True, False):
|
||||
@ -142,7 +158,9 @@ class TestTunnel(IsolatedAsyncioTestCase):
|
||||
f'http{"s" if use_ssl else ""}://127.0.0.1:{server.port}/'
|
||||
)
|
||||
_, cfg = fixtures.get_config(
|
||||
uds_server=fake_uds_server, uds_verify_ssl=False
|
||||
uds_server=fake_uds_server,
|
||||
uds_verify_ssl=False,
|
||||
listen_protocol='http',
|
||||
)
|
||||
self.assertEqual(
|
||||
await TestTunnel.get(fake_uds_server),
|
||||
@ -155,6 +173,37 @@ class TestTunnel(IsolatedAsyncioTestCase):
|
||||
)
|
||||
self.assertEqual(ret, {'result': 'ok'})
|
||||
|
||||
async def test_tunnel_invalid_command(self) -> None:
|
||||
# Test invalid handshake
|
||||
# data = b''
|
||||
# future: asyncio.Future = asyncio.Future()
|
||||
|
||||
# def callback(ldata: bytes) -> None:
|
||||
# nonlocal data
|
||||
# data += ldata
|
||||
# future.set_result(True)
|
||||
|
||||
# Send invalid commands and see what happens
|
||||
# Commands are 4 bytes length, try with less and more invalid commands
|
||||
for i in range(0, 100, 10):
|
||||
# Set timeout to 1 seconds
|
||||
bad_cmd = bytes(random.randint(0, 255) for _ in range(i)) # Some garbage
|
||||
consts.TIMEOUT_COMMAND = 0.1 # type: ignore # timeout is a final variable, but we need to change it for testing speed
|
||||
logger.info(f'Testing invalid command with {bad_cmd!r}')
|
||||
async with TestTunnel.create_test_tunnel(lambda x: None) as cfg:
|
||||
# Open connection to tunnel
|
||||
async with TestTunnel.open_tunnel(cfg) as (reader, writer):
|
||||
# Send data
|
||||
writer.write(bad_cmd)
|
||||
await writer.drain()
|
||||
# Wait for response
|
||||
readed = await reader.read(1024)
|
||||
# Should return consts.ERROR_COMMAND or consts.ERROR_TIMEOUT
|
||||
if len(bad_cmd) < 4:
|
||||
self.assertEqual(readed, consts.RESPONSE_ERROR_TIMEOUT)
|
||||
else:
|
||||
self.assertEqual(readed, consts.RESPONSE_ERROR_COMMAND)
|
||||
|
||||
# Helpers
|
||||
@staticmethod
|
||||
async def get(url: str) -> str:
|
||||
@ -165,3 +214,74 @@ class TestTunnel(IsolatedAsyncioTestCase):
|
||||
async with session.get(url, **options) as r:
|
||||
r.raise_for_status()
|
||||
return await r.text()
|
||||
|
||||
@staticmethod
|
||||
async def create_tunnel_server(
|
||||
cfg: 'config.ConfigurationType', context: 'ssl.SSLContext'
|
||||
) -> 'asyncio.Server':
|
||||
# Create fake proxy
|
||||
proxy = mock.MagicMock()
|
||||
proxy.cfg = cfg
|
||||
proxy.ns = mock.MagicMock()
|
||||
proxy.ns.current = 0
|
||||
proxy.ns.total = 0
|
||||
proxy.ns.sent = 0
|
||||
proxy.ns.recv = 0
|
||||
proxy.counter = 0
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
# Create an asyncio listen socket on cfg.listen_host:cfg.listen_port
|
||||
return await loop.create_server(
|
||||
lambda: tunnel.TunnelProtocol(proxy),
|
||||
cfg.listen_address,
|
||||
cfg.listen_port,
|
||||
ssl=context,
|
||||
family=socket.AF_INET6 if cfg.listen_ipv6 or ':' in cfg.listen_address else socket.AF_INET,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@contextlib.asynccontextmanager
|
||||
async def create_test_tunnel(callback: typing.Callable[[bytes], None]) -> typing.AsyncGenerator['config.ConfigurationType', None]:
|
||||
# Generate a listening server for testing tunnel
|
||||
# Prepare the end of the tunnel
|
||||
async with tools.AsyncTCPServer(port=54876, callback=callback) as server:
|
||||
# Create a tunnel to localhost 13579
|
||||
# SSl cert for tunnel server
|
||||
with certs.ssl_context(server.host) as (ssl_ctx, _):
|
||||
_, cfg = fixtures.get_config(
|
||||
address=server.host,
|
||||
port=7777,
|
||||
)
|
||||
with mock.patch(
|
||||
'uds_tunnel.tunnel.TunnelProtocol._readFromUDS',
|
||||
new_callable=tools.AsyncMock,
|
||||
) as m:
|
||||
m.return_value = UDS_GET_TICKET_RESPONSE(server.host, server.port)
|
||||
|
||||
tunnel_server = await TestTunnel.create_tunnel_server(cfg, ssl_ctx)
|
||||
yield cfg
|
||||
tunnel_server.close()
|
||||
await tunnel_server.wait_closed()
|
||||
|
||||
@staticmethod
|
||||
@contextlib.asynccontextmanager
|
||||
async def open_tunnel(
|
||||
cfg: 'config.ConfigurationType',
|
||||
) -> typing.AsyncGenerator[typing.Tuple[asyncio.StreamReader, asyncio.StreamWriter], None]:
|
||||
""" opens an ssl socket to the tunnel server
|
||||
"""
|
||||
if cfg.listen_ipv6 or ':' in cfg.listen_address:
|
||||
family = socket.AF_INET6
|
||||
else:
|
||||
family = socket.AF_INET
|
||||
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
context.check_hostname = False
|
||||
context.verify_mode = ssl.CERT_NONE
|
||||
reader, writer = await asyncio.open_connection(
|
||||
cfg.listen_address, cfg.listen_port, ssl=context, family=family
|
||||
)
|
||||
yield reader, writer
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
|
||||
|
@ -1,8 +1,12 @@
|
||||
import secrets
|
||||
import random
|
||||
from datetime import datetime, timedelta
|
||||
import datetime
|
||||
import tempfile
|
||||
import ipaddress
|
||||
import typing
|
||||
import ssl
|
||||
import os
|
||||
import contextlib
|
||||
|
||||
from cryptography import x509
|
||||
from cryptography.x509.oid import NameOID
|
||||
@ -25,7 +29,7 @@ def selfSignedCert(ip: str) -> typing.Tuple[str, str, str]:
|
||||
san = x509.SubjectAlternativeName([x509.IPAddress(ipaddress.ip_address(ip))])
|
||||
|
||||
basic_contraints = x509.BasicConstraints(ca=True, path_length=0)
|
||||
now = datetime.utcnow()
|
||||
now = datetime.datetime.utcnow()
|
||||
cert = (
|
||||
x509.CertificateBuilder()
|
||||
.subject_name(name)
|
||||
@ -33,7 +37,7 @@ def selfSignedCert(ip: str) -> typing.Tuple[str, str, str]:
|
||||
.public_key(key.public_key())
|
||||
.serial_number(random.SystemRandom().randint(0, 1 << 64))
|
||||
.not_valid_before(now)
|
||||
.not_valid_after(now + timedelta(days=10 * 365))
|
||||
.not_valid_after(now + datetime.timedelta(days=10 * 365))
|
||||
.add_extension(basic_contraints, False)
|
||||
.add_extension(san, False)
|
||||
.sign(key, hashes.SHA256(), default_backend())
|
||||
@ -50,3 +54,47 @@ def selfSignedCert(ip: str) -> typing.Tuple[str, str, str]:
|
||||
cert.public_bytes(encoding=serialization.Encoding.PEM).decode(),
|
||||
password,
|
||||
)
|
||||
|
||||
|
||||
def sslContext(ip: str) -> typing.Tuple[ssl.SSLContext, str, str]:
|
||||
"""Returns an ssl context an the certificate & password for an ip
|
||||
|
||||
Args:
|
||||
ip (str): Ip for subject name
|
||||
|
||||
Returns:
|
||||
typing.Tuple[ssl.SSLContext, str, str]: ssl context, certificate file and password
|
||||
"""
|
||||
# First, create server cert and key on temp dir
|
||||
tmpdir = tempfile.gettempdir()
|
||||
cert, key, password = selfSignedCert('127.0.0.1')
|
||||
cert_file = f'{tmpdir}/tmp_cert.pem'
|
||||
with open(cert_file, 'w') as f:
|
||||
f.write(key)
|
||||
f.write(cert)
|
||||
# Create SSL context
|
||||
ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||
ssl_ctx.options |= ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1
|
||||
ssl_ctx.load_cert_chain(certfile=f'{tmpdir}/tmp_cert.pem', password=password)
|
||||
ssl_ctx.check_hostname = False
|
||||
ssl_ctx.set_ciphers('ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384')
|
||||
|
||||
return ssl_ctx, cert_file, password
|
||||
|
||||
@contextlib.contextmanager
|
||||
def ssl_context(ip: str) -> typing.Generator[typing.Tuple[ssl.SSLContext, str], None, None]:
|
||||
"""Returns an ssl context for an ip
|
||||
|
||||
Args:
|
||||
ip (str): Ip for subject name
|
||||
|
||||
Returns:
|
||||
ssl.SSLContext: ssl context
|
||||
"""
|
||||
# First, create server cert and key on temp dir
|
||||
ssl_ctx, cert_file, password = sslContext(ip)
|
||||
|
||||
yield ssl_ctx, cert_file
|
||||
|
||||
# Remove cert file
|
||||
os.remove(cert_file)
|
||||
|
@ -37,6 +37,7 @@ from unittest import mock
|
||||
|
||||
from . import certs
|
||||
|
||||
|
||||
class AsyncMock(mock.MagicMock):
|
||||
async def __call__(self, *args, **kwargs):
|
||||
return super().__call__(*args, **kwargs)
|
||||
@ -44,42 +45,37 @@ class AsyncMock(mock.MagicMock):
|
||||
|
||||
# simple async http server, will return 200 OK with the request path as body
|
||||
class AsyncHttpServer:
|
||||
host: str
|
||||
port: int
|
||||
_server: typing.Optional[asyncio.AbstractServer]
|
||||
_response: typing.Optional[bytes]
|
||||
_ssl_ctx: typing.Optional[ssl.SSLContext]
|
||||
_ssl_cert_file: typing.Optional[str]
|
||||
|
||||
def __init__(
|
||||
self, port: int, *, response: typing.Optional[bytes] = None, use_ssl: bool = False,
|
||||
self,
|
||||
port: int,
|
||||
*,
|
||||
response: typing.Optional[bytes] = None,
|
||||
use_ssl: bool = False,
|
||||
host: str = '127.0.0.1' # ip
|
||||
):
|
||||
) -> None:
|
||||
self.host = host
|
||||
self.port = port
|
||||
self._server = None
|
||||
self._response = response
|
||||
if use_ssl:
|
||||
# First, create server cert and key on temp dir
|
||||
tmpdir = tempfile.gettempdir()
|
||||
cert, key, password = certs.selfSignedCert('127.0.0.1')
|
||||
with open(f'{tmpdir}/tmp_cert.pem', 'w') as f:
|
||||
f.write(key)
|
||||
f.write(cert)
|
||||
# Create SSL context
|
||||
self._ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||
self._ssl_ctx.options |= ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1
|
||||
self._ssl_ctx.load_cert_chain(certfile=f'{tmpdir}/tmp_cert.pem', password=password)
|
||||
self._ssl_ctx.check_hostname = False
|
||||
self._ssl_ctx.set_ciphers(
|
||||
'ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384'
|
||||
)
|
||||
self._ssl_ctx, self._ssl_cert_file, pwd = certs.sslContext(host)
|
||||
else:
|
||||
self._ssl_ctx = None
|
||||
self._ssl_cert_file = None
|
||||
|
||||
# on end, remove certs
|
||||
def __del__(self):
|
||||
tmpdir = tempfile.gettempdir()
|
||||
# os.remove(f'{tmpdir}/tmp_cert.pem')
|
||||
def __del__(self) -> None:
|
||||
if self._ssl_cert_file:
|
||||
os.unlink(self._ssl_cert_file)
|
||||
|
||||
async def _handle(self, reader, writer):
|
||||
async def _handle(self, reader, writer) -> None:
|
||||
data = await reader.read(2048)
|
||||
path: bytes = data.split()[1]
|
||||
if self._response is not None:
|
||||
@ -90,20 +86,69 @@ class AsyncHttpServer:
|
||||
)
|
||||
await writer.drain()
|
||||
|
||||
async def __aenter__(self):
|
||||
async def __aenter__(self) -> 'AsyncHttpServer':
|
||||
if self._ssl_ctx is not None:
|
||||
self._server = await asyncio.start_server(
|
||||
self._handle, '127.0.0.1', self.port, ssl=self._ssl_ctx
|
||||
self._handle, self.host, self.port, ssl=self._ssl_ctx
|
||||
)
|
||||
else:
|
||||
self._server = await asyncio.start_server(
|
||||
self._handle, '127.0.0.1', self.port
|
||||
self._handle, self.host, self.port
|
||||
)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
|
||||
if self._server is not None:
|
||||
self._server.close()
|
||||
await self._server.wait_closed()
|
||||
self._server = None
|
||||
|
||||
class AsyncTCPServer:
|
||||
host: str
|
||||
port: int
|
||||
_server: typing.Optional[asyncio.AbstractServer]
|
||||
_response: typing.Optional[bytes]
|
||||
_callback: typing.Optional[typing.Callable[[bytes], None]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
port: int,
|
||||
*,
|
||||
response: typing.Optional[bytes] = None,
|
||||
host: str = '127.0.0.1', # ip
|
||||
callback: typing.Optional[typing.Callable[[bytes], None]] = None
|
||||
) -> None:
|
||||
self.host = host
|
||||
self.port = port
|
||||
self._server = None
|
||||
self._response = response
|
||||
self._callback = callback
|
||||
|
||||
self.data = b''
|
||||
|
||||
async def _handle(self, reader, writer) -> None:
|
||||
data = await reader.read(2048)
|
||||
|
||||
if self._callback:
|
||||
self._callback(data)
|
||||
|
||||
if self._response is not None:
|
||||
data = self._response
|
||||
else:
|
||||
data = b'sample data'
|
||||
writer.write(data)
|
||||
await writer.drain()
|
||||
|
||||
async def __aenter__(self) -> 'AsyncTCPServer':
|
||||
self._server = await asyncio.start_server(
|
||||
self._handle,
|
||||
self.host,
|
||||
self.port
|
||||
)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
|
||||
if self._server is not None:
|
||||
self._server.close()
|
||||
await self._server.wait_closed()
|
||||
self._server = None
|
||||
|
Loading…
x
Reference in New Issue
Block a user