mirror of
https://github.com/dkmstr/openuds.git
synced 2024-12-22 13:34:04 +03:00
backported 4.0 version improvements
This commit is contained in:
parent
392cb6e406
commit
2c77d361d7
@ -74,11 +74,17 @@ class Proxy:
|
||||
# Handshake correct in this point, upgrade the connection to TSL and let
|
||||
# the protocol controller do the rest
|
||||
|
||||
# Store source ip and port, for logging purposes in case of error
|
||||
src_ip, src_port = (source.getpeername() if source else ('Unknown', 0))[:2] # May be ipv4 or ipv6, so we get only first two elements
|
||||
|
||||
# Upgrade connection to SSL, and use asyncio to handle the rest
|
||||
tun: typing.Optional[tunnel.TunnelProtocol] = None
|
||||
try:
|
||||
tun = tunnel.TunnelProtocol(self)
|
||||
# (connect accepted loop not present on AbastractEventLoop definition < 3.10), that's why we use ignore
|
||||
await loop.connect_accepted_socket( # type: ignore
|
||||
lambda: tunnel.TunnelProtocol(self), source, ssl=context
|
||||
lambda: tun, source, ssl=context,
|
||||
ssl_handshake_timeout=3,
|
||||
)
|
||||
|
||||
# Wait for connection to be closed
|
||||
@ -86,6 +92,11 @@ class Proxy:
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass # Return on cancel
|
||||
except Exception as e:
|
||||
# Any other exception, ensure we close the connection
|
||||
logger.error('ERROR on %s:%s: %s', src_ip, src_port, e)
|
||||
if tun:
|
||||
tun.close_connection()
|
||||
|
||||
logger.debug('Proxy finished')
|
||||
|
||||
|
@ -97,8 +97,6 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
# We start processing command
|
||||
# After command, we can process stats or do_proxy, that is the "normal" operation
|
||||
self.runner = self.do_command
|
||||
# Set starting timeout task, se we dont get hunged on connections without data (or insufficient data)
|
||||
self.set_timeout(self.owner.cfg.command_timeout)
|
||||
|
||||
def process_open(self) -> None:
|
||||
# Open Command has the ticket behind it
|
||||
@ -158,7 +156,7 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
self.transport.write(b'OK')
|
||||
self.stats_manager.increment_connections() # Increment connections counters
|
||||
except Exception as e:
|
||||
logger.error('Error opening connection: %s', e)
|
||||
logger.error('CONNECTION FAILED: %s', e)
|
||||
self.close_connection()
|
||||
|
||||
# add open other side to the loop
|
||||
@ -276,9 +274,10 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
|
||||
def close_connection(self):
|
||||
try:
|
||||
self.clean_timeout() # If a timeout is set, clean it
|
||||
if not self.transport.is_closing():
|
||||
self.transport.close()
|
||||
except Exception:
|
||||
except Exception: # nosec: best effort
|
||||
pass # Ignore errors
|
||||
|
||||
def notify_end(self):
|
||||
@ -304,6 +303,10 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
|
||||
def connection_made(self, transport: 'asyncio.transports.BaseTransport') -> None:
|
||||
# We know for sure that the transport is a Transport.
|
||||
|
||||
# Set starting timeout task, se we dont get hunged on connections without data (or insufficient data)
|
||||
self.set_timeout(self.owner.cfg.command_timeout)
|
||||
|
||||
self.transport = typing.cast('asyncio.transports.Transport', transport)
|
||||
# Get source
|
||||
self.source = self.transport.get_extra_info('peername')
|
||||
|
@ -37,20 +37,22 @@ import argparse
|
||||
import signal
|
||||
import ssl
|
||||
import socket
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
# event for stop notification
|
||||
import threading
|
||||
import threading # event for stop notification
|
||||
import typing
|
||||
import logging
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
|
||||
try:
|
||||
import uvloop
|
||||
import uvloop # type: ignore
|
||||
|
||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||
except ImportError:
|
||||
pass # no uvloop support
|
||||
|
||||
try:
|
||||
import setproctitle
|
||||
import setproctitle # type: ignore
|
||||
except ImportError:
|
||||
setproctitle = None # type: ignore
|
||||
|
||||
@ -73,8 +75,6 @@ def stop_signal(signum: int, frame: typing.Any) -> None:
|
||||
|
||||
|
||||
def setup_log(cfg: config.ConfigurationType) -> None:
|
||||
from logging.handlers import RotatingFileHandler
|
||||
|
||||
# Update logging if needed
|
||||
if cfg.logfile:
|
||||
fileh = RotatingFileHandler(
|
||||
@ -96,9 +96,7 @@ def setup_log(cfg: config.ConfigurationType) -> None:
|
||||
log.setLevel(cfg.loglevel)
|
||||
handler = logging.StreamHandler(sys.stderr)
|
||||
handler.setLevel(cfg.loglevel)
|
||||
formatter = logging.Formatter(
|
||||
'%(levelname)s - %(message)s'
|
||||
) # Basic log format, nice for syslog
|
||||
formatter = logging.Formatter('%(levelname)s - %(message)s') # Basic log format, nice for syslog
|
||||
handler.setFormatter(formatter)
|
||||
log.addHandler(handler)
|
||||
|
||||
@ -107,25 +105,25 @@ def setup_log(cfg: config.ConfigurationType) -> None:
|
||||
logger.debug('Configuration: %s', cfg)
|
||||
|
||||
|
||||
async def tunnel_proc_async(
|
||||
pipe: 'Connection', cfg: config.ConfigurationType, ns: 'Namespace'
|
||||
) -> None:
|
||||
|
||||
async def tunnel_proc_async(pipe: 'Connection', cfg: config.ConfigurationType, ns: 'Namespace') -> None:
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
tasks: typing.List[asyncio.Task] = []
|
||||
|
||||
def add_autoremovable_task(task: asyncio.Task) -> None:
|
||||
tasks.append(task)
|
||||
task.add_done_callback(tasks.remove)
|
||||
|
||||
def remove_task(task: asyncio.Task) -> None:
|
||||
logger.debug('Removing task %s', task)
|
||||
tasks.remove(task)
|
||||
|
||||
task.add_done_callback(remove_task)
|
||||
|
||||
def get_socket() -> typing.Tuple[typing.Optional[socket.socket], typing.Optional[typing.Tuple[str, int]]]:
|
||||
try:
|
||||
while True:
|
||||
# Clear back event, for next data
|
||||
msg: typing.Optional[
|
||||
typing.Tuple[socket.socket, typing.Tuple[str, int]]
|
||||
] = pipe.recv()
|
||||
msg: typing.Optional[typing.Tuple[socket.socket, typing.Tuple[str, int]]] = pipe.recv()
|
||||
if msg:
|
||||
return msg
|
||||
except EOFError:
|
||||
@ -147,19 +145,31 @@ async def tunnel_proc_async(
|
||||
args['keyfile'] = cfg.ssl_certificate_key
|
||||
if cfg.ssl_password:
|
||||
args['password'] = cfg.ssl_password
|
||||
|
||||
|
||||
context.load_cert_chain(**args)
|
||||
|
||||
# Set min version from string (1.2 or 1.3) as ssl.TLSVersion.TLSv1_2 or ssl.TLSVersion.TLSv1_3
|
||||
if cfg.ssl_min_tls_version in ('1.2', '1.3'):
|
||||
context.minimum_version = getattr(ssl.TLSVersion, f'TLSv1_{cfg.ssl_min_tls_version.split(".")[1]}')
|
||||
try:
|
||||
context.minimum_version = getattr(
|
||||
ssl.TLSVersion, f'TLSv1_{cfg.ssl_min_tls_version.split(".")[1]}'
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception('Setting min tls version failed: %s. Using defaults', e)
|
||||
context.minimum_version = ssl.TLSVersion.TLSv1_2
|
||||
# Any other value will be ignored
|
||||
|
||||
if cfg.ssl_ciphers:
|
||||
context.set_ciphers(cfg.ssl_ciphers)
|
||||
try:
|
||||
context.set_ciphers(cfg.ssl_ciphers)
|
||||
except Exception as e:
|
||||
logger.exception('Setting ciphers failed: %s. Using defaults', e)
|
||||
|
||||
if cfg.ssl_dhparam:
|
||||
context.load_dh_params(cfg.ssl_dhparam)
|
||||
try:
|
||||
context.load_dh_params(cfg.ssl_dhparam)
|
||||
except Exception as e:
|
||||
logger.exception('Loading dhparams failed: %s. Using defaults', e)
|
||||
|
||||
try:
|
||||
while True:
|
||||
@ -168,28 +178,34 @@ async def tunnel_proc_async(
|
||||
(sock, address) = await loop.run_in_executor(None, get_socket)
|
||||
if not sock:
|
||||
break # No more sockets, exit
|
||||
logger.debug(f'CONNECTION from {address!r} (pid: {os.getpid()})')
|
||||
logger.debug('CONNECTION from %s (pid: %s)', address, os.getpid())
|
||||
# Due to proxy contains an "event" to stop, we need to create a new one for each connection
|
||||
add_autoremovable_task(asyncio.create_task(proxy.Proxy(cfg, ns)(sock, context)))
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
add_autoremovable_task(
|
||||
asyncio.create_task(proxy.Proxy(cfg, ns)(sock, context), name=f'proxy-{address}')
|
||||
)
|
||||
except asyncio.CancelledError: # pylint: disable=try-except-raise
|
||||
raise # Stop, but avoid generic exception
|
||||
except Exception:
|
||||
logger.error('NEGOTIATION ERROR from %s', address[0] if address else 'unknown')
|
||||
except asyncio.CancelledError:
|
||||
pass # Stop
|
||||
|
||||
# create task for server
|
||||
|
||||
add_autoremovable_task(asyncio.create_task(run_server()))
|
||||
|
||||
add_autoremovable_task(asyncio.create_task(run_server(), name='server'))
|
||||
|
||||
try:
|
||||
while tasks and not do_stop.is_set():
|
||||
to_wait = tasks[:] # Get a copy of the list
|
||||
# Wait for "to_wait" tasks to finish, stop every 2 seconds to check if we need to stop
|
||||
done, _ = await asyncio.wait(to_wait, return_when=asyncio.FIRST_COMPLETED, timeout=2)
|
||||
# done, _ =
|
||||
await asyncio.wait(to_wait, return_when=asyncio.FIRST_COMPLETED, timeout=2)
|
||||
except asyncio.CancelledError:
|
||||
logger.info('Task cancelled')
|
||||
do_stop.set() # ensure we stop
|
||||
except Exception:
|
||||
logger.exception('Error in main loop')
|
||||
do_stop.set()
|
||||
|
||||
logger.debug('Out of loop, stopping tasks: %s, running: %s', tasks, do_stop.is_set())
|
||||
|
||||
@ -199,7 +215,7 @@ async def tunnel_proc_async(
|
||||
task.cancel()
|
||||
except asyncio.CancelledError:
|
||||
pass # Ignore, we are stopping
|
||||
|
||||
|
||||
# for task in tasks:
|
||||
# task.cancel()
|
||||
|
||||
@ -208,16 +224,15 @@ async def tunnel_proc_async(
|
||||
|
||||
logger.info('PROCESS %s stopped', os.getpid())
|
||||
|
||||
def process_connection(
|
||||
client: socket.socket, addr: typing.Tuple[str, str], conn: 'Connection'
|
||||
) -> None:
|
||||
|
||||
def process_connection(client: socket.socket, addr: typing.Tuple[str, str], conn: 'Connection') -> None:
|
||||
data: bytes = b''
|
||||
try:
|
||||
# First, ensure handshake (simple handshake) and command
|
||||
data = client.recv(len(consts.HANDSHAKE_V1))
|
||||
|
||||
if data != consts.HANDSHAKE_V1:
|
||||
raise Exception('Invalid data from {}: {}'.format(addr[0], data.hex())) # Invalid handshake
|
||||
raise Exception(f'Invalid data from {addr[0]}: {data.hex()}') # Invalid handshake
|
||||
conn.send((client, addr))
|
||||
del client # Ensure socket is controlled on child process
|
||||
except Exception as e:
|
||||
@ -231,11 +246,11 @@ def tunnel_main(args: 'argparse.Namespace') -> None:
|
||||
|
||||
# Try to bind to port as running user
|
||||
# Wait for socket incoming connections and spread them
|
||||
socket.setdefaulttimeout(
|
||||
3.0
|
||||
) # So we can check for stop from time to time and not block forever
|
||||
af_inet = socket.AF_INET6 if args.ipv6 or cfg.ipv6 or ':' in cfg.listen_address else socket.AF_INET
|
||||
sock = socket.socket(af_inet, socket.SOCK_STREAM)
|
||||
socket.setdefaulttimeout(3.0) # So we can check for stop from time to time and not block forever
|
||||
sock = socket.socket(
|
||||
socket.AF_INET6 if args.ipv6 or cfg.ipv6 or ':' in cfg.listen_address else socket.AF_INET,
|
||||
socket.SOCK_STREAM,
|
||||
)
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)
|
||||
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
|
||||
# We will not reuse port, we only want a UDS tunnel server running on a port
|
||||
@ -257,15 +272,13 @@ def tunnel_main(args: 'argparse.Namespace') -> None:
|
||||
|
||||
setup_log(cfg)
|
||||
|
||||
logger.info(
|
||||
'Starting tunnel server on %s:%s', cfg.listen_address, cfg.listen_port
|
||||
)
|
||||
logger.info('Starting tunnel server on %s:%s', cfg.listen_address, cfg.listen_port)
|
||||
if setproctitle:
|
||||
setproctitle.setproctitle(f'UDSTunnel {cfg.listen_address}:{cfg.listen_port}')
|
||||
|
||||
# Create pid file
|
||||
if cfg.pidfile:
|
||||
with open(cfg.pidfile, mode='w') as f:
|
||||
with open(cfg.pidfile, mode='w', encoding='utf-8') as f:
|
||||
f.write(str(os.getpid()))
|
||||
|
||||
except Exception as e:
|
||||
@ -278,7 +291,7 @@ def tunnel_main(args: 'argparse.Namespace') -> None:
|
||||
signal.signal(signal.SIGINT, stop_signal)
|
||||
signal.signal(signal.SIGTERM, stop_signal)
|
||||
except Exception as e:
|
||||
# Signal not available on threads, and we use threads on tests,
|
||||
# Signal not available on threads, and we use threads on tests,
|
||||
# so we will ignore this because on tests signals are not important
|
||||
logger.warning('Signal not available: %s', e)
|
||||
|
||||
@ -317,7 +330,7 @@ def tunnel_main(args: 'argparse.Namespace') -> None:
|
||||
if cfg.pidfile:
|
||||
os.unlink(cfg.pidfile)
|
||||
except Exception:
|
||||
pass
|
||||
logger.warning('Could not remove pidfile %s', cfg.pidfile)
|
||||
|
||||
logger.info('FINISHED')
|
||||
|
||||
@ -325,9 +338,7 @@ def tunnel_main(args: 'argparse.Namespace') -> None:
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
group = parser.add_mutually_exclusive_group()
|
||||
group.add_argument(
|
||||
'-t', '--tunnel', help='Starts the tunnel server', action='store_true'
|
||||
)
|
||||
group.add_argument('-t', '--tunnel', help='Starts the tunnel server', action='store_true')
|
||||
# group.add_argument('-r', '--rdp', help='RDP Tunnel for traffic accounting')
|
||||
group.add_argument(
|
||||
'-s',
|
||||
|
Loading…
Reference in New Issue
Block a user