1
0
mirror of https://github.com/dkmstr/openuds.git synced 2024-12-22 13:34:04 +03:00

backported 4.0 version improvements

This commit is contained in:
Adolfo Gómez García 2023-05-21 16:23:18 +02:00
parent 392cb6e406
commit 2c77d361d7
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
3 changed files with 79 additions and 54 deletions

View File

@ -74,11 +74,17 @@ class Proxy:
# Handshake correct in this point, upgrade the connection to TSL and let # Handshake correct in this point, upgrade the connection to TSL and let
# the protocol controller do the rest # the protocol controller do the rest
# Store source ip and port, for logging purposes in case of error
src_ip, src_port = (source.getpeername() if source else ('Unknown', 0))[:2] # May be ipv4 or ipv6, so we get only first two elements
# Upgrade connection to SSL, and use asyncio to handle the rest # Upgrade connection to SSL, and use asyncio to handle the rest
tun: typing.Optional[tunnel.TunnelProtocol] = None
try: try:
tun = tunnel.TunnelProtocol(self)
# (connect accepted loop not present on AbastractEventLoop definition < 3.10), that's why we use ignore # (connect accepted loop not present on AbastractEventLoop definition < 3.10), that's why we use ignore
await loop.connect_accepted_socket( # type: ignore await loop.connect_accepted_socket( # type: ignore
lambda: tunnel.TunnelProtocol(self), source, ssl=context lambda: tun, source, ssl=context,
ssl_handshake_timeout=3,
) )
# Wait for connection to be closed # Wait for connection to be closed
@ -86,6 +92,11 @@ class Proxy:
except asyncio.CancelledError: except asyncio.CancelledError:
pass # Return on cancel pass # Return on cancel
except Exception as e:
# Any other exception, ensure we close the connection
logger.error('ERROR on %s:%s: %s', src_ip, src_port, e)
if tun:
tun.close_connection()
logger.debug('Proxy finished') logger.debug('Proxy finished')

View File

@ -97,8 +97,6 @@ class TunnelProtocol(asyncio.Protocol):
# We start processing command # We start processing command
# After command, we can process stats or do_proxy, that is the "normal" operation # After command, we can process stats or do_proxy, that is the "normal" operation
self.runner = self.do_command self.runner = self.do_command
# Set starting timeout task, se we dont get hunged on connections without data (or insufficient data)
self.set_timeout(self.owner.cfg.command_timeout)
def process_open(self) -> None: def process_open(self) -> None:
# Open Command has the ticket behind it # Open Command has the ticket behind it
@ -158,7 +156,7 @@ class TunnelProtocol(asyncio.Protocol):
self.transport.write(b'OK') self.transport.write(b'OK')
self.stats_manager.increment_connections() # Increment connections counters self.stats_manager.increment_connections() # Increment connections counters
except Exception as e: except Exception as e:
logger.error('Error opening connection: %s', e) logger.error('CONNECTION FAILED: %s', e)
self.close_connection() self.close_connection()
# add open other side to the loop # add open other side to the loop
@ -276,9 +274,10 @@ class TunnelProtocol(asyncio.Protocol):
def close_connection(self): def close_connection(self):
try: try:
self.clean_timeout() # If a timeout is set, clean it
if not self.transport.is_closing(): if not self.transport.is_closing():
self.transport.close() self.transport.close()
except Exception: except Exception: # nosec: best effort
pass # Ignore errors pass # Ignore errors
def notify_end(self): def notify_end(self):
@ -304,6 +303,10 @@ class TunnelProtocol(asyncio.Protocol):
def connection_made(self, transport: 'asyncio.transports.BaseTransport') -> None: def connection_made(self, transport: 'asyncio.transports.BaseTransport') -> None:
# We know for sure that the transport is a Transport. # We know for sure that the transport is a Transport.
# Set starting timeout task, se we dont get hunged on connections without data (or insufficient data)
self.set_timeout(self.owner.cfg.command_timeout)
self.transport = typing.cast('asyncio.transports.Transport', transport) self.transport = typing.cast('asyncio.transports.Transport', transport)
# Get source # Get source
self.source = self.transport.get_extra_info('peername') self.source = self.transport.get_extra_info('peername')

View File

@ -37,20 +37,22 @@ import argparse
import signal import signal
import ssl import ssl
import socket import socket
import logging import threading # event for stop notification
from concurrent.futures import ThreadPoolExecutor
# event for stop notification
import threading
import typing import typing
import logging
from logging.handlers import RotatingFileHandler
from concurrent.futures import ThreadPoolExecutor
try: try:
import uvloop import uvloop # type: ignore
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
except ImportError: except ImportError:
pass # no uvloop support pass # no uvloop support
try: try:
import setproctitle import setproctitle # type: ignore
except ImportError: except ImportError:
setproctitle = None # type: ignore setproctitle = None # type: ignore
@ -73,8 +75,6 @@ def stop_signal(signum: int, frame: typing.Any) -> None:
def setup_log(cfg: config.ConfigurationType) -> None: def setup_log(cfg: config.ConfigurationType) -> None:
from logging.handlers import RotatingFileHandler
# Update logging if needed # Update logging if needed
if cfg.logfile: if cfg.logfile:
fileh = RotatingFileHandler( fileh = RotatingFileHandler(
@ -96,9 +96,7 @@ def setup_log(cfg: config.ConfigurationType) -> None:
log.setLevel(cfg.loglevel) log.setLevel(cfg.loglevel)
handler = logging.StreamHandler(sys.stderr) handler = logging.StreamHandler(sys.stderr)
handler.setLevel(cfg.loglevel) handler.setLevel(cfg.loglevel)
formatter = logging.Formatter( formatter = logging.Formatter('%(levelname)s - %(message)s') # Basic log format, nice for syslog
'%(levelname)s - %(message)s'
) # Basic log format, nice for syslog
handler.setFormatter(formatter) handler.setFormatter(formatter)
log.addHandler(handler) log.addHandler(handler)
@ -107,25 +105,25 @@ def setup_log(cfg: config.ConfigurationType) -> None:
logger.debug('Configuration: %s', cfg) logger.debug('Configuration: %s', cfg)
async def tunnel_proc_async( async def tunnel_proc_async(pipe: 'Connection', cfg: config.ConfigurationType, ns: 'Namespace') -> None:
pipe: 'Connection', cfg: config.ConfigurationType, ns: 'Namespace'
) -> None:
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
tasks: typing.List[asyncio.Task] = [] tasks: typing.List[asyncio.Task] = []
def add_autoremovable_task(task: asyncio.Task) -> None: def add_autoremovable_task(task: asyncio.Task) -> None:
tasks.append(task) tasks.append(task)
task.add_done_callback(tasks.remove)
def remove_task(task: asyncio.Task) -> None:
logger.debug('Removing task %s', task)
tasks.remove(task)
task.add_done_callback(remove_task)
def get_socket() -> typing.Tuple[typing.Optional[socket.socket], typing.Optional[typing.Tuple[str, int]]]: def get_socket() -> typing.Tuple[typing.Optional[socket.socket], typing.Optional[typing.Tuple[str, int]]]:
try: try:
while True: while True:
# Clear back event, for next data # Clear back event, for next data
msg: typing.Optional[ msg: typing.Optional[typing.Tuple[socket.socket, typing.Tuple[str, int]]] = pipe.recv()
typing.Tuple[socket.socket, typing.Tuple[str, int]]
] = pipe.recv()
if msg: if msg:
return msg return msg
except EOFError: except EOFError:
@ -152,14 +150,26 @@ async def tunnel_proc_async(
# Set min version from string (1.2 or 1.3) as ssl.TLSVersion.TLSv1_2 or ssl.TLSVersion.TLSv1_3 # Set min version from string (1.2 or 1.3) as ssl.TLSVersion.TLSv1_2 or ssl.TLSVersion.TLSv1_3
if cfg.ssl_min_tls_version in ('1.2', '1.3'): if cfg.ssl_min_tls_version in ('1.2', '1.3'):
context.minimum_version = getattr(ssl.TLSVersion, f'TLSv1_{cfg.ssl_min_tls_version.split(".")[1]}') try:
context.minimum_version = getattr(
ssl.TLSVersion, f'TLSv1_{cfg.ssl_min_tls_version.split(".")[1]}'
)
except Exception as e:
logger.exception('Setting min tls version failed: %s. Using defaults', e)
context.minimum_version = ssl.TLSVersion.TLSv1_2
# Any other value will be ignored # Any other value will be ignored
if cfg.ssl_ciphers: if cfg.ssl_ciphers:
try:
context.set_ciphers(cfg.ssl_ciphers) context.set_ciphers(cfg.ssl_ciphers)
except Exception as e:
logger.exception('Setting ciphers failed: %s. Using defaults', e)
if cfg.ssl_dhparam: if cfg.ssl_dhparam:
try:
context.load_dh_params(cfg.ssl_dhparam) context.load_dh_params(cfg.ssl_dhparam)
except Exception as e:
logger.exception('Loading dhparams failed: %s. Using defaults', e)
try: try:
while True: while True:
@ -168,11 +178,13 @@ async def tunnel_proc_async(
(sock, address) = await loop.run_in_executor(None, get_socket) (sock, address) = await loop.run_in_executor(None, get_socket)
if not sock: if not sock:
break # No more sockets, exit break # No more sockets, exit
logger.debug(f'CONNECTION from {address!r} (pid: {os.getpid()})') logger.debug('CONNECTION from %s (pid: %s)', address, os.getpid())
# Due to proxy contains an "event" to stop, we need to create a new one for each connection # Due to proxy contains an "event" to stop, we need to create a new one for each connection
add_autoremovable_task(asyncio.create_task(proxy.Proxy(cfg, ns)(sock, context))) add_autoremovable_task(
except asyncio.CancelledError: asyncio.create_task(proxy.Proxy(cfg, ns)(sock, context), name=f'proxy-{address}')
raise )
except asyncio.CancelledError: # pylint: disable=try-except-raise
raise # Stop, but avoid generic exception
except Exception: except Exception:
logger.error('NEGOTIATION ERROR from %s', address[0] if address else 'unknown') logger.error('NEGOTIATION ERROR from %s', address[0] if address else 'unknown')
except asyncio.CancelledError: except asyncio.CancelledError:
@ -180,16 +192,20 @@ async def tunnel_proc_async(
# create task for server # create task for server
add_autoremovable_task(asyncio.create_task(run_server())) add_autoremovable_task(asyncio.create_task(run_server(), name='server'))
try: try:
while tasks and not do_stop.is_set(): while tasks and not do_stop.is_set():
to_wait = tasks[:] # Get a copy of the list to_wait = tasks[:] # Get a copy of the list
# Wait for "to_wait" tasks to finish, stop every 2 seconds to check if we need to stop # Wait for "to_wait" tasks to finish, stop every 2 seconds to check if we need to stop
done, _ = await asyncio.wait(to_wait, return_when=asyncio.FIRST_COMPLETED, timeout=2) # done, _ =
await asyncio.wait(to_wait, return_when=asyncio.FIRST_COMPLETED, timeout=2)
except asyncio.CancelledError: except asyncio.CancelledError:
logger.info('Task cancelled') logger.info('Task cancelled')
do_stop.set() # ensure we stop do_stop.set() # ensure we stop
except Exception:
logger.exception('Error in main loop')
do_stop.set()
logger.debug('Out of loop, stopping tasks: %s, running: %s', tasks, do_stop.is_set()) logger.debug('Out of loop, stopping tasks: %s, running: %s', tasks, do_stop.is_set())
@ -208,16 +224,15 @@ async def tunnel_proc_async(
logger.info('PROCESS %s stopped', os.getpid()) logger.info('PROCESS %s stopped', os.getpid())
def process_connection(
client: socket.socket, addr: typing.Tuple[str, str], conn: 'Connection' def process_connection(client: socket.socket, addr: typing.Tuple[str, str], conn: 'Connection') -> None:
) -> None:
data: bytes = b'' data: bytes = b''
try: try:
# First, ensure handshake (simple handshake) and command # First, ensure handshake (simple handshake) and command
data = client.recv(len(consts.HANDSHAKE_V1)) data = client.recv(len(consts.HANDSHAKE_V1))
if data != consts.HANDSHAKE_V1: if data != consts.HANDSHAKE_V1:
raise Exception('Invalid data from {}: {}'.format(addr[0], data.hex())) # Invalid handshake raise Exception(f'Invalid data from {addr[0]}: {data.hex()}') # Invalid handshake
conn.send((client, addr)) conn.send((client, addr))
del client # Ensure socket is controlled on child process del client # Ensure socket is controlled on child process
except Exception as e: except Exception as e:
@ -231,11 +246,11 @@ def tunnel_main(args: 'argparse.Namespace') -> None:
# Try to bind to port as running user # Try to bind to port as running user
# Wait for socket incoming connections and spread them # Wait for socket incoming connections and spread them
socket.setdefaulttimeout( socket.setdefaulttimeout(3.0) # So we can check for stop from time to time and not block forever
3.0 sock = socket.socket(
) # So we can check for stop from time to time and not block forever socket.AF_INET6 if args.ipv6 or cfg.ipv6 or ':' in cfg.listen_address else socket.AF_INET,
af_inet = socket.AF_INET6 if args.ipv6 or cfg.ipv6 or ':' in cfg.listen_address else socket.AF_INET socket.SOCK_STREAM,
sock = socket.socket(af_inet, socket.SOCK_STREAM) )
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
# We will not reuse port, we only want a UDS tunnel server running on a port # We will not reuse port, we only want a UDS tunnel server running on a port
@ -257,15 +272,13 @@ def tunnel_main(args: 'argparse.Namespace') -> None:
setup_log(cfg) setup_log(cfg)
logger.info( logger.info('Starting tunnel server on %s:%s', cfg.listen_address, cfg.listen_port)
'Starting tunnel server on %s:%s', cfg.listen_address, cfg.listen_port
)
if setproctitle: if setproctitle:
setproctitle.setproctitle(f'UDSTunnel {cfg.listen_address}:{cfg.listen_port}') setproctitle.setproctitle(f'UDSTunnel {cfg.listen_address}:{cfg.listen_port}')
# Create pid file # Create pid file
if cfg.pidfile: if cfg.pidfile:
with open(cfg.pidfile, mode='w') as f: with open(cfg.pidfile, mode='w', encoding='utf-8') as f:
f.write(str(os.getpid())) f.write(str(os.getpid()))
except Exception as e: except Exception as e:
@ -317,7 +330,7 @@ def tunnel_main(args: 'argparse.Namespace') -> None:
if cfg.pidfile: if cfg.pidfile:
os.unlink(cfg.pidfile) os.unlink(cfg.pidfile)
except Exception: except Exception:
pass logger.warning('Could not remove pidfile %s', cfg.pidfile)
logger.info('FINISHED') logger.info('FINISHED')
@ -325,9 +338,7 @@ def tunnel_main(args: 'argparse.Namespace') -> None:
def main() -> None: def main() -> None:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
group = parser.add_mutually_exclusive_group() group = parser.add_mutually_exclusive_group()
group.add_argument( group.add_argument('-t', '--tunnel', help='Starts the tunnel server', action='store_true')
'-t', '--tunnel', help='Starts the tunnel server', action='store_true'
)
# group.add_argument('-r', '--rdp', help='RDP Tunnel for traffic accounting') # group.add_argument('-r', '--rdp', help='RDP Tunnel for traffic accounting')
group.add_argument( group.add_argument(
'-s', '-s',