mirror of
https://github.com/dkmstr/openuds.git
synced 2024-12-22 13:34:04 +03:00
backported 4.0 version improvements
This commit is contained in:
parent
392cb6e406
commit
2c77d361d7
@ -74,11 +74,17 @@ class Proxy:
|
|||||||
# Handshake correct in this point, upgrade the connection to TSL and let
|
# Handshake correct in this point, upgrade the connection to TSL and let
|
||||||
# the protocol controller do the rest
|
# the protocol controller do the rest
|
||||||
|
|
||||||
|
# Store source ip and port, for logging purposes in case of error
|
||||||
|
src_ip, src_port = (source.getpeername() if source else ('Unknown', 0))[:2] # May be ipv4 or ipv6, so we get only first two elements
|
||||||
|
|
||||||
# Upgrade connection to SSL, and use asyncio to handle the rest
|
# Upgrade connection to SSL, and use asyncio to handle the rest
|
||||||
|
tun: typing.Optional[tunnel.TunnelProtocol] = None
|
||||||
try:
|
try:
|
||||||
|
tun = tunnel.TunnelProtocol(self)
|
||||||
# (connect accepted loop not present on AbastractEventLoop definition < 3.10), that's why we use ignore
|
# (connect accepted loop not present on AbastractEventLoop definition < 3.10), that's why we use ignore
|
||||||
await loop.connect_accepted_socket( # type: ignore
|
await loop.connect_accepted_socket( # type: ignore
|
||||||
lambda: tunnel.TunnelProtocol(self), source, ssl=context
|
lambda: tun, source, ssl=context,
|
||||||
|
ssl_handshake_timeout=3,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Wait for connection to be closed
|
# Wait for connection to be closed
|
||||||
@ -86,6 +92,11 @@ class Proxy:
|
|||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass # Return on cancel
|
pass # Return on cancel
|
||||||
|
except Exception as e:
|
||||||
|
# Any other exception, ensure we close the connection
|
||||||
|
logger.error('ERROR on %s:%s: %s', src_ip, src_port, e)
|
||||||
|
if tun:
|
||||||
|
tun.close_connection()
|
||||||
|
|
||||||
logger.debug('Proxy finished')
|
logger.debug('Proxy finished')
|
||||||
|
|
||||||
|
@ -97,8 +97,6 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
# We start processing command
|
# We start processing command
|
||||||
# After command, we can process stats or do_proxy, that is the "normal" operation
|
# After command, we can process stats or do_proxy, that is the "normal" operation
|
||||||
self.runner = self.do_command
|
self.runner = self.do_command
|
||||||
# Set starting timeout task, se we dont get hunged on connections without data (or insufficient data)
|
|
||||||
self.set_timeout(self.owner.cfg.command_timeout)
|
|
||||||
|
|
||||||
def process_open(self) -> None:
|
def process_open(self) -> None:
|
||||||
# Open Command has the ticket behind it
|
# Open Command has the ticket behind it
|
||||||
@ -158,7 +156,7 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
self.transport.write(b'OK')
|
self.transport.write(b'OK')
|
||||||
self.stats_manager.increment_connections() # Increment connections counters
|
self.stats_manager.increment_connections() # Increment connections counters
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error('Error opening connection: %s', e)
|
logger.error('CONNECTION FAILED: %s', e)
|
||||||
self.close_connection()
|
self.close_connection()
|
||||||
|
|
||||||
# add open other side to the loop
|
# add open other side to the loop
|
||||||
@ -276,9 +274,10 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
|
|
||||||
def close_connection(self):
|
def close_connection(self):
|
||||||
try:
|
try:
|
||||||
|
self.clean_timeout() # If a timeout is set, clean it
|
||||||
if not self.transport.is_closing():
|
if not self.transport.is_closing():
|
||||||
self.transport.close()
|
self.transport.close()
|
||||||
except Exception:
|
except Exception: # nosec: best effort
|
||||||
pass # Ignore errors
|
pass # Ignore errors
|
||||||
|
|
||||||
def notify_end(self):
|
def notify_end(self):
|
||||||
@ -304,6 +303,10 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
|
|
||||||
def connection_made(self, transport: 'asyncio.transports.BaseTransport') -> None:
|
def connection_made(self, transport: 'asyncio.transports.BaseTransport') -> None:
|
||||||
# We know for sure that the transport is a Transport.
|
# We know for sure that the transport is a Transport.
|
||||||
|
|
||||||
|
# Set starting timeout task, se we dont get hunged on connections without data (or insufficient data)
|
||||||
|
self.set_timeout(self.owner.cfg.command_timeout)
|
||||||
|
|
||||||
self.transport = typing.cast('asyncio.transports.Transport', transport)
|
self.transport = typing.cast('asyncio.transports.Transport', transport)
|
||||||
# Get source
|
# Get source
|
||||||
self.source = self.transport.get_extra_info('peername')
|
self.source = self.transport.get_extra_info('peername')
|
||||||
|
@ -37,20 +37,22 @@ import argparse
|
|||||||
import signal
|
import signal
|
||||||
import ssl
|
import ssl
|
||||||
import socket
|
import socket
|
||||||
import logging
|
import threading # event for stop notification
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
# event for stop notification
|
|
||||||
import threading
|
|
||||||
import typing
|
import typing
|
||||||
|
import logging
|
||||||
|
from logging.handlers import RotatingFileHandler
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import uvloop
|
import uvloop # type: ignore
|
||||||
|
|
||||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass # no uvloop support
|
pass # no uvloop support
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import setproctitle
|
import setproctitle # type: ignore
|
||||||
except ImportError:
|
except ImportError:
|
||||||
setproctitle = None # type: ignore
|
setproctitle = None # type: ignore
|
||||||
|
|
||||||
@ -73,8 +75,6 @@ def stop_signal(signum: int, frame: typing.Any) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def setup_log(cfg: config.ConfigurationType) -> None:
|
def setup_log(cfg: config.ConfigurationType) -> None:
|
||||||
from logging.handlers import RotatingFileHandler
|
|
||||||
|
|
||||||
# Update logging if needed
|
# Update logging if needed
|
||||||
if cfg.logfile:
|
if cfg.logfile:
|
||||||
fileh = RotatingFileHandler(
|
fileh = RotatingFileHandler(
|
||||||
@ -96,9 +96,7 @@ def setup_log(cfg: config.ConfigurationType) -> None:
|
|||||||
log.setLevel(cfg.loglevel)
|
log.setLevel(cfg.loglevel)
|
||||||
handler = logging.StreamHandler(sys.stderr)
|
handler = logging.StreamHandler(sys.stderr)
|
||||||
handler.setLevel(cfg.loglevel)
|
handler.setLevel(cfg.loglevel)
|
||||||
formatter = logging.Formatter(
|
formatter = logging.Formatter('%(levelname)s - %(message)s') # Basic log format, nice for syslog
|
||||||
'%(levelname)s - %(message)s'
|
|
||||||
) # Basic log format, nice for syslog
|
|
||||||
handler.setFormatter(formatter)
|
handler.setFormatter(formatter)
|
||||||
log.addHandler(handler)
|
log.addHandler(handler)
|
||||||
|
|
||||||
@ -107,25 +105,25 @@ def setup_log(cfg: config.ConfigurationType) -> None:
|
|||||||
logger.debug('Configuration: %s', cfg)
|
logger.debug('Configuration: %s', cfg)
|
||||||
|
|
||||||
|
|
||||||
async def tunnel_proc_async(
|
async def tunnel_proc_async(pipe: 'Connection', cfg: config.ConfigurationType, ns: 'Namespace') -> None:
|
||||||
pipe: 'Connection', cfg: config.ConfigurationType, ns: 'Namespace'
|
|
||||||
) -> None:
|
|
||||||
|
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
tasks: typing.List[asyncio.Task] = []
|
tasks: typing.List[asyncio.Task] = []
|
||||||
|
|
||||||
def add_autoremovable_task(task: asyncio.Task) -> None:
|
def add_autoremovable_task(task: asyncio.Task) -> None:
|
||||||
tasks.append(task)
|
tasks.append(task)
|
||||||
task.add_done_callback(tasks.remove)
|
|
||||||
|
def remove_task(task: asyncio.Task) -> None:
|
||||||
|
logger.debug('Removing task %s', task)
|
||||||
|
tasks.remove(task)
|
||||||
|
|
||||||
|
task.add_done_callback(remove_task)
|
||||||
|
|
||||||
def get_socket() -> typing.Tuple[typing.Optional[socket.socket], typing.Optional[typing.Tuple[str, int]]]:
|
def get_socket() -> typing.Tuple[typing.Optional[socket.socket], typing.Optional[typing.Tuple[str, int]]]:
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
# Clear back event, for next data
|
# Clear back event, for next data
|
||||||
msg: typing.Optional[
|
msg: typing.Optional[typing.Tuple[socket.socket, typing.Tuple[str, int]]] = pipe.recv()
|
||||||
typing.Tuple[socket.socket, typing.Tuple[str, int]]
|
|
||||||
] = pipe.recv()
|
|
||||||
if msg:
|
if msg:
|
||||||
return msg
|
return msg
|
||||||
except EOFError:
|
except EOFError:
|
||||||
@ -152,14 +150,26 @@ async def tunnel_proc_async(
|
|||||||
|
|
||||||
# Set min version from string (1.2 or 1.3) as ssl.TLSVersion.TLSv1_2 or ssl.TLSVersion.TLSv1_3
|
# Set min version from string (1.2 or 1.3) as ssl.TLSVersion.TLSv1_2 or ssl.TLSVersion.TLSv1_3
|
||||||
if cfg.ssl_min_tls_version in ('1.2', '1.3'):
|
if cfg.ssl_min_tls_version in ('1.2', '1.3'):
|
||||||
context.minimum_version = getattr(ssl.TLSVersion, f'TLSv1_{cfg.ssl_min_tls_version.split(".")[1]}')
|
try:
|
||||||
|
context.minimum_version = getattr(
|
||||||
|
ssl.TLSVersion, f'TLSv1_{cfg.ssl_min_tls_version.split(".")[1]}'
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception('Setting min tls version failed: %s. Using defaults', e)
|
||||||
|
context.minimum_version = ssl.TLSVersion.TLSv1_2
|
||||||
# Any other value will be ignored
|
# Any other value will be ignored
|
||||||
|
|
||||||
if cfg.ssl_ciphers:
|
if cfg.ssl_ciphers:
|
||||||
|
try:
|
||||||
context.set_ciphers(cfg.ssl_ciphers)
|
context.set_ciphers(cfg.ssl_ciphers)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception('Setting ciphers failed: %s. Using defaults', e)
|
||||||
|
|
||||||
if cfg.ssl_dhparam:
|
if cfg.ssl_dhparam:
|
||||||
|
try:
|
||||||
context.load_dh_params(cfg.ssl_dhparam)
|
context.load_dh_params(cfg.ssl_dhparam)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception('Loading dhparams failed: %s. Using defaults', e)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
@ -168,11 +178,13 @@ async def tunnel_proc_async(
|
|||||||
(sock, address) = await loop.run_in_executor(None, get_socket)
|
(sock, address) = await loop.run_in_executor(None, get_socket)
|
||||||
if not sock:
|
if not sock:
|
||||||
break # No more sockets, exit
|
break # No more sockets, exit
|
||||||
logger.debug(f'CONNECTION from {address!r} (pid: {os.getpid()})')
|
logger.debug('CONNECTION from %s (pid: %s)', address, os.getpid())
|
||||||
# Due to proxy contains an "event" to stop, we need to create a new one for each connection
|
# Due to proxy contains an "event" to stop, we need to create a new one for each connection
|
||||||
add_autoremovable_task(asyncio.create_task(proxy.Proxy(cfg, ns)(sock, context)))
|
add_autoremovable_task(
|
||||||
except asyncio.CancelledError:
|
asyncio.create_task(proxy.Proxy(cfg, ns)(sock, context), name=f'proxy-{address}')
|
||||||
raise
|
)
|
||||||
|
except asyncio.CancelledError: # pylint: disable=try-except-raise
|
||||||
|
raise # Stop, but avoid generic exception
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.error('NEGOTIATION ERROR from %s', address[0] if address else 'unknown')
|
logger.error('NEGOTIATION ERROR from %s', address[0] if address else 'unknown')
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
@ -180,16 +192,20 @@ async def tunnel_proc_async(
|
|||||||
|
|
||||||
# create task for server
|
# create task for server
|
||||||
|
|
||||||
add_autoremovable_task(asyncio.create_task(run_server()))
|
add_autoremovable_task(asyncio.create_task(run_server(), name='server'))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while tasks and not do_stop.is_set():
|
while tasks and not do_stop.is_set():
|
||||||
to_wait = tasks[:] # Get a copy of the list
|
to_wait = tasks[:] # Get a copy of the list
|
||||||
# Wait for "to_wait" tasks to finish, stop every 2 seconds to check if we need to stop
|
# Wait for "to_wait" tasks to finish, stop every 2 seconds to check if we need to stop
|
||||||
done, _ = await asyncio.wait(to_wait, return_when=asyncio.FIRST_COMPLETED, timeout=2)
|
# done, _ =
|
||||||
|
await asyncio.wait(to_wait, return_when=asyncio.FIRST_COMPLETED, timeout=2)
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.info('Task cancelled')
|
logger.info('Task cancelled')
|
||||||
do_stop.set() # ensure we stop
|
do_stop.set() # ensure we stop
|
||||||
|
except Exception:
|
||||||
|
logger.exception('Error in main loop')
|
||||||
|
do_stop.set()
|
||||||
|
|
||||||
logger.debug('Out of loop, stopping tasks: %s, running: %s', tasks, do_stop.is_set())
|
logger.debug('Out of loop, stopping tasks: %s, running: %s', tasks, do_stop.is_set())
|
||||||
|
|
||||||
@ -208,16 +224,15 @@ async def tunnel_proc_async(
|
|||||||
|
|
||||||
logger.info('PROCESS %s stopped', os.getpid())
|
logger.info('PROCESS %s stopped', os.getpid())
|
||||||
|
|
||||||
def process_connection(
|
|
||||||
client: socket.socket, addr: typing.Tuple[str, str], conn: 'Connection'
|
def process_connection(client: socket.socket, addr: typing.Tuple[str, str], conn: 'Connection') -> None:
|
||||||
) -> None:
|
|
||||||
data: bytes = b''
|
data: bytes = b''
|
||||||
try:
|
try:
|
||||||
# First, ensure handshake (simple handshake) and command
|
# First, ensure handshake (simple handshake) and command
|
||||||
data = client.recv(len(consts.HANDSHAKE_V1))
|
data = client.recv(len(consts.HANDSHAKE_V1))
|
||||||
|
|
||||||
if data != consts.HANDSHAKE_V1:
|
if data != consts.HANDSHAKE_V1:
|
||||||
raise Exception('Invalid data from {}: {}'.format(addr[0], data.hex())) # Invalid handshake
|
raise Exception(f'Invalid data from {addr[0]}: {data.hex()}') # Invalid handshake
|
||||||
conn.send((client, addr))
|
conn.send((client, addr))
|
||||||
del client # Ensure socket is controlled on child process
|
del client # Ensure socket is controlled on child process
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -231,11 +246,11 @@ def tunnel_main(args: 'argparse.Namespace') -> None:
|
|||||||
|
|
||||||
# Try to bind to port as running user
|
# Try to bind to port as running user
|
||||||
# Wait for socket incoming connections and spread them
|
# Wait for socket incoming connections and spread them
|
||||||
socket.setdefaulttimeout(
|
socket.setdefaulttimeout(3.0) # So we can check for stop from time to time and not block forever
|
||||||
3.0
|
sock = socket.socket(
|
||||||
) # So we can check for stop from time to time and not block forever
|
socket.AF_INET6 if args.ipv6 or cfg.ipv6 or ':' in cfg.listen_address else socket.AF_INET,
|
||||||
af_inet = socket.AF_INET6 if args.ipv6 or cfg.ipv6 or ':' in cfg.listen_address else socket.AF_INET
|
socket.SOCK_STREAM,
|
||||||
sock = socket.socket(af_inet, socket.SOCK_STREAM)
|
)
|
||||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)
|
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)
|
||||||
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
|
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
|
||||||
# We will not reuse port, we only want a UDS tunnel server running on a port
|
# We will not reuse port, we only want a UDS tunnel server running on a port
|
||||||
@ -257,15 +272,13 @@ def tunnel_main(args: 'argparse.Namespace') -> None:
|
|||||||
|
|
||||||
setup_log(cfg)
|
setup_log(cfg)
|
||||||
|
|
||||||
logger.info(
|
logger.info('Starting tunnel server on %s:%s', cfg.listen_address, cfg.listen_port)
|
||||||
'Starting tunnel server on %s:%s', cfg.listen_address, cfg.listen_port
|
|
||||||
)
|
|
||||||
if setproctitle:
|
if setproctitle:
|
||||||
setproctitle.setproctitle(f'UDSTunnel {cfg.listen_address}:{cfg.listen_port}')
|
setproctitle.setproctitle(f'UDSTunnel {cfg.listen_address}:{cfg.listen_port}')
|
||||||
|
|
||||||
# Create pid file
|
# Create pid file
|
||||||
if cfg.pidfile:
|
if cfg.pidfile:
|
||||||
with open(cfg.pidfile, mode='w') as f:
|
with open(cfg.pidfile, mode='w', encoding='utf-8') as f:
|
||||||
f.write(str(os.getpid()))
|
f.write(str(os.getpid()))
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -317,7 +330,7 @@ def tunnel_main(args: 'argparse.Namespace') -> None:
|
|||||||
if cfg.pidfile:
|
if cfg.pidfile:
|
||||||
os.unlink(cfg.pidfile)
|
os.unlink(cfg.pidfile)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
logger.warning('Could not remove pidfile %s', cfg.pidfile)
|
||||||
|
|
||||||
logger.info('FINISHED')
|
logger.info('FINISHED')
|
||||||
|
|
||||||
@ -325,9 +338,7 @@ def tunnel_main(args: 'argparse.Namespace') -> None:
|
|||||||
def main() -> None:
|
def main() -> None:
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
group = parser.add_mutually_exclusive_group()
|
group = parser.add_mutually_exclusive_group()
|
||||||
group.add_argument(
|
group.add_argument('-t', '--tunnel', help='Starts the tunnel server', action='store_true')
|
||||||
'-t', '--tunnel', help='Starts the tunnel server', action='store_true'
|
|
||||||
)
|
|
||||||
# group.add_argument('-r', '--rdp', help='RDP Tunnel for traffic accounting')
|
# group.add_argument('-r', '--rdp', help='RDP Tunnel for traffic accounting')
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
'-s',
|
'-s',
|
||||||
|
Loading…
Reference in New Issue
Block a user