1
0
mirror of https://github.com/dkmstr/openuds.git synced 2024-12-22 13:34:04 +03:00

updates from 4.0 backported

This commit is contained in:
Adolfo Gómez García 2022-12-19 01:25:29 +01:00
parent b7962a24f1
commit adeb6b2a46
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
8 changed files with 225 additions and 101 deletions

View File

@ -126,9 +126,7 @@ async def main():
data = client.recv(4)
print(data)
# Upgrade connection to SSL, and use asyncio to handle the rest
transport: 'asyncio.transports.Transport'
protocol: TunnelProtocol
(transport, protocol) = await loop.connect_accepted_socket( # type: ignore
(_, protocol) = await loop.connect_accepted_socket( # type: ignore
lambda: TunnelProtocol(), client, ssl=context
)

View File

@ -43,18 +43,21 @@ class ConfigurationType(typing.NamedTuple):
pidfile: str
user: str
log_level: str
log_file: str
log_size: int
log_number: int
loglevel: str
logfile: str
logsize: int
lognumber: int
listen_address: str
listen_port: int
ipv6: bool
workers: int
ssl_certificate: str
ssl_certificate_key: str
ssl_password: str
ssl_ciphers: str
ssl_dhparam: str
@ -111,15 +114,17 @@ def read(
return ConfigurationType(
pidfile=uds.get('pidfile', ''),
user=uds.get('user', ''),
log_level=uds.get('loglevel', 'ERROR'),
log_file=uds.get('logfile', ''),
log_size=int(logsize) * 1024 * 1024,
log_number=int(uds.get('lognumber', '3')),
loglevel=uds.get('loglevel', 'ERROR'),
logfile=uds.get('logfile', ''),
logsize=int(logsize) * 1024 * 1024,
lognumber=int(uds.get('lognumber', '3')),
listen_address=uds.get('address', '0.0.0.0'),
listen_port=int(uds.get('port', '443')),
ipv6=uds.get('ipv6', 'false').lower() == 'true',
workers=int(uds.get('workers', '0')) or multiprocessing.cpu_count(),
ssl_certificate=uds['ssl_certificate'],
ssl_certificate_key=uds['ssl_certificate_key'],
ssl_certificate_key=uds.get('ssl_certificate_key', ''),
ssl_password=uds.get('ssl_password', ''),
ssl_ciphers=uds.get('ssl_ciphers'),
ssl_dhparam=uds.get('ssl_dhparam'),
uds_server=uds_server,

View File

@ -28,34 +28,48 @@
'''
Author: Adolfo Gómez, dkmaster at dkmon dot com
'''
import typing
DEBUG = True
if DEBUG:
CONFIGFILE = 'udstunnel.conf'
LOGFORMAT = '%(levelname)s %(asctime)s %(message)s'
else:
CONFIGFILE = '/etc/udstunnel.conf'
LOGFORMAT = '%(levelname)s %(asctime)s %(message)s'
CONFIGFILE: typing.Final[str] = '/etc/udstunnel.conf' if not DEBUG else 'udstunnel.conf'
LOGFORMAT: typing.Final[str] = (
'%(levelname)s %(asctime)s %(message)s'
if not DEBUG
else '%(levelname)s %(asctime)s %(message)s'
)
# MAX Length of read buffer for proxyed requests
BUFFER_SIZE = 1024 * 16
BUFFER_SIZE: typing.Final[int] = 1024 * 16
# Handshake for conversation start
HANDSHAKE_V1 = b'\x5AMGB\xA5\x01\x00'
HANDSHAKE_V1: typing.Final[bytes] = b'\x5AMGB\xA5\x01\x00'
# Ticket length
TICKET_LENGTH = 48
TICKET_LENGTH: typing.Final[int] = 48
# Max Admin password length (stats basically right now)
PASSWORD_LENGTH = 64
PASSWORD_LENGTH: typing.Final[int] = 64
# Bandwidth calc time lapse
BANDWIDTH_TIME = 10
BANDWIDTH_TIME: typing.Final[int] = 10
# Commands LENGTH (all same length)
COMMAND_LENGTH = 4
COMMAND_LENGTH: typing.Final[int] = 4
VERSION = 'v2.0.0'
VERSION: typing.Final[str] = 'v2.0.0'
# Valid commands
COMMAND_OPEN = b'OPEN'
COMMAND_TEST = b'TEST'
COMMAND_STAT = b'STAT' # full stats
COMMAND_INFO = b'INFO' # Basic stats, currently same as FULL
COMMAND_OPEN: typing.Final[bytes] = b'OPEN'
COMMAND_TEST: typing.Final[bytes] = b'TEST'
COMMAND_STAT: typing.Final[bytes] = b'STAT' # full stats
COMMAND_INFO: typing.Final[bytes] = b'INFO' # Basic stats, currently same as FULL
RESPONSE_ERROR_TICKET: typing.Final[bytes] = b'ERROR_TICKET'
RESPONSE_ERROR_COMMAND: typing.Final[bytes] = b'ERROR_COMMAND'
RESPONSE_ERROR_TIMEOUT: typing.Final[bytes] = b'TIMEOUT'
RESPONSE_FORBIDDEN: typing.Final[bytes] = b'FORBIDDEN'
RESPONSE_OK: typing.Final[bytes] = b'OK'
# Timeout for command
TIMEOUT_COMMAND: typing.Final[int] = 3
# Backlog for listen socket
BACKLOG = 1024

View File

@ -156,13 +156,16 @@ class Processes:
ns: 'Namespace',
) -> None:
if cfg.use_uvloop:
import uvloop
try:
import uvloop
if sys.version_info >= (3, 11):
with asyncio.Runner(loop_factory=uvloop.new_event_loop) as runner:
runner.run(proc(conn, cfg, ns))
else:
uvloop.install()
asyncio.run(proc(conn, cfg, ns))
if sys.version_info >= (3, 11):
with asyncio.Runner(loop_factory=uvloop.new_event_loop) as runner:
runner.run(proc(conn, cfg, ns))
else:
uvloop.install()
asyncio.run(proc(conn, cfg, ns))
except ImportError:
logger.warning('uvloop not found, using default asyncio')
else:
asyncio.run(proc(conn, cfg, ns))

View File

@ -46,6 +46,7 @@ logger = logging.getLogger(__name__)
class Proxy:
cfg: 'config.ConfigurationType'
ns: 'Namespace'
finished: asyncio.Event
def __init__(self, cfg: 'config.ConfigurationType', ns: 'Namespace') -> None:
self.cfg = cfg
@ -63,22 +64,27 @@ class Proxy:
addr = source.getpeername()
except Exception:
addr = 'Unknown'
logger.error('Proxy error from %s: %s', addr, e)
logger.exception('Proxy error from %s: %s (%s--%s)', addr, e, source, context)
async def proxy(self, source: socket.socket, context: 'ssl.SSLContext') -> None:
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
# Handshake correct in this point, upgrade the connection to TSL and let
# the protocol controller do the rest
self.finished = asyncio.Event()
# Upgrade connection to SSL, and use asyncio to handle the rest
try:
protocol: tunnel.TunnelProtocol
# (connect accepted loop not present on AbastractEventLoop definition < 3.10)
(_, protocol) = await loop.connect_accepted_socket( # type: ignore
lambda: tunnel.TunnelProtocol(self), source, ssl=context
def factory() -> tunnel.TunnelProtocol:
return tunnel.TunnelProtocol(self)
# (connect accepted loop not present on AbastractEventLoop definition < 3.10), that's why we use ignore
await loop.connect_accepted_socket( # type: ignore
factory, source, ssl=context
)
await protocol.finished
# Wait for connection to be closed
await self.finished.wait()
except asyncio.CancelledError:
pass # Return on cancel

View File

@ -31,6 +31,7 @@ Author: Adolfo Gómez, dkmaster at dkmon dot com
import asyncio
import typing
import logging
import socket
import aiohttp
@ -46,8 +47,6 @@ if typing.TYPE_CHECKING:
# Protocol
class TunnelProtocol(asyncio.Protocol):
# future to mark eof
finished: asyncio.Future
# Transport and other side of tunnel
transport: 'asyncio.transports.Transport'
other_side: 'TunnelProtocol'
@ -56,7 +55,7 @@ class TunnelProtocol(asyncio.Protocol):
# Command buffer
cmd: bytes
# Ticket
notify_ticket: bytes
notify_ticket: bytes # Only exists on "slave" transport (that is, tunnel from us to remote machine)
# owner Proxy class
owner: 'proxy.Proxy'
# source of connection
@ -66,6 +65,8 @@ class TunnelProtocol(asyncio.Protocol):
stats_manager: stats.Stats
# counter
counter: stats.StatsSingleCounter
# If there is a timeout task running
timeout_task: typing.Optional[asyncio.Task] = None
def __init__(
self, owner: 'proxy.Proxy', other_side: typing.Optional['TunnelProtocol'] = None
@ -82,19 +83,23 @@ class TunnelProtocol(asyncio.Protocol):
self.stats_manager = stats.Stats(owner.ns)
self.counter = self.stats_manager.as_sent_counter()
self.runner = self.do_command
# Set starting timeout task, se we dont get hunged on connections without data
self.set_timeout(consts.TIMEOUT_COMMAND)
# transport is undefined until connection_made is called
self.finished = asyncio.Future()
self.cmd = b''
self.notify_ticket = b''
self.owner = owner
self.source = ('', 0)
self.destination = ('', 0)
def process_open(self) -> None:
# Open Command has the ticket behind it
if len(self.cmd) < consts.TICKET_LENGTH + consts.COMMAND_LENGTH:
# Reactivate timeout, will be deactivated on do_command
self.set_timeout(consts.TIMEOUT_COMMAND)
return # Wait for more data to complete OPEN command
# Ticket received, now process it with UDS
@ -106,7 +111,7 @@ class TunnelProtocol(asyncio.Protocol):
# clean up the command
self.cmd = b''
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
async def open_other_side() -> None:
try:
@ -115,7 +120,7 @@ class TunnelProtocol(asyncio.Protocol):
)
except Exception as e:
logger.error('ERROR %s', e.args[0] if e.args else e)
self.transport.write(b'ERROR_TICKET')
self.transport.write(consts.RESPONSE_ERROR_TICKET)
self.transport.close() # And force close
return
@ -130,10 +135,12 @@ class TunnelProtocol(asyncio.Protocol):
)
try:
family = socket.AF_INET6 if ':' in self.destination[0] or self.owner.cfg.ipv6 else socket.AF_INET
(_, protocol) = await loop.create_connection(
lambda: TunnelProtocol(self.owner, self),
self.destination[0],
self.destination[1],
family=family,
)
self.other_side = typing.cast('TunnelProtocol', protocol)
@ -145,6 +152,7 @@ class TunnelProtocol(asyncio.Protocol):
logger.error('Error opening connection: %s', e)
self.close_connection()
# add open other side to the loop
loop.create_task(open_other_side())
# From now, proxy connection
self.runner = self.do_proxy
@ -160,7 +168,7 @@ class TunnelProtocol(asyncio.Protocol):
# Check valid source ip
if self.transport.get_extra_info('peername')[0] not in self.owner.cfg.allow:
# Invalid source
self.transport.write(b'FORBIDDEN')
self.transport.write(consts.RESPONSE_FORBIDDEN)
return
# Check password
@ -171,7 +179,7 @@ class TunnelProtocol(asyncio.Protocol):
if passwd.decode(errors='ignore') != self.owner.cfg.secret:
# Invalid password
self.transport.write(b'FORBIDDEN')
self.transport.write(consts.RESPONSE_FORBIDDEN)
return
data = stats.GlobalStats.get_stats(self.owner.ns)
@ -184,18 +192,50 @@ class TunnelProtocol(asyncio.Protocol):
finally:
self.close_connection()
async def timeout(self, wait: int) -> None:
""" Timeout can only occur while waiting for a command."""
try:
await asyncio.sleep(wait)
logger.error('TIMEOUT FROM %s', self.pretty_source())
self.transport.write(consts.RESPONSE_ERROR_TIMEOUT)
self.close_connection()
except asyncio.CancelledError:
pass
def set_timeout(self, wait: int) -> None:
"""Set a timeout for this connection.
If reached, the connection will be closed.
Args:
wait (int): Timeout in seconds
"""
if self.timeout_task:
self.timeout_task.cancel()
self.timeout_task = asyncio.create_task(self.timeout(wait))
def clean_timeout(self) -> None:
"""Clean the timeout task if any.
"""
if self.timeout_task:
self.timeout_task.cancel()
self.timeout_task = None
def do_command(self, data: bytes) -> None:
self.cmd += data
if len(self.cmd) >= consts.COMMAND_LENGTH:
if self.cmd == b'':
logger.info('CONNECT FROM %s', self.pretty_source())
self.clean_timeout()
self.cmd += data
# Ensure we don't get a timeout
if len(self.cmd) >= consts.COMMAND_LENGTH:
command = self.cmd[: consts.COMMAND_LENGTH]
try:
if command == consts.COMMAND_OPEN:
self.process_open()
elif command == consts.COMMAND_TEST:
logger.info('COMMAND: TEST')
self.transport.write(b'OK')
self.transport.write(consts.RESPONSE_OK)
self.close_connection()
return
elif command in (consts.COMMAND_STAT, consts.COMMAND_INFO):
@ -206,9 +246,11 @@ class TunnelProtocol(asyncio.Protocol):
raise Exception('Invalid command')
except Exception:
logger.error('ERROR from %s', self.pretty_source())
self.transport.write(b'ERROR_COMMAND')
self.transport.write(consts.RESPONSE_ERROR_COMMAND)
self.close_connection()
return
else:
self.set_timeout(consts.TIMEOUT_COMMAND)
# if not enough data to process command, wait for more
@ -221,6 +263,7 @@ class TunnelProtocol(asyncio.Protocol):
def connection_made(self, transport: 'asyncio.transports.BaseTransport') -> None:
logger.debug('Connection made: %s', transport.get_extra_info('peername'))
self.main = True # This is the main connection
# We know for sure that the transport is a Transport.
self.transport = typing.cast('asyncio.transports.Transport', transport)
@ -239,10 +282,12 @@ class TunnelProtocol(asyncio.Protocol):
)
)
self.notify_ticket = b'' # Clean up so no more notifications
else: # No ticket, this is "main" connection (from client to us). Notify owner that we are done
self.owner.finished.set()
def connection_lost(self, exc: typing.Optional[Exception]) -> None:
logger.debug('Connection closed : %s', exc)
self.finished.set_result(True)
# Ensure close other side if any
if self.other_side is not self:
self.other_side.transport.close()
else:
@ -250,12 +295,17 @@ class TunnelProtocol(asyncio.Protocol):
self.notifyEnd()
# helpers
@staticmethod
def pretty_address(address: typing.Tuple[str, int]) -> str:
if ':' in address[0]:
return '[' + address[0] + ']:' + str(address[1])
return address[0] + ':' + str(address[1])
# source address, pretty format
def pretty_source(self) -> str:
return self.source[0] + ':' + str(self.source[1])
return TunnelProtocol.pretty_address(self.source)
def pretty_destination(self) -> str:
return self.destination[0] + ':' + str(self.destination[1])
return TunnelProtocol.pretty_address(self.destination)
def close_connection(self):
self.transport.close()

View File

@ -20,14 +20,21 @@ lognumber = 3
# Listen address. Defaults to 0.0.0.0
address = 0.0.0.0
# Number of workers. Defaults to 0 (means "as much as cores")
workers = 2
# Listening port
port = 7777
# If force ipv6, defaults to false
# Note: if listen address is an ipv6 address, this will be forced to true
# This will force dns resolution to ipv6
ipv6 = false
# Number of workers. Defaults to 0 (means "as much as cores")
workers = 2
# SSL Related parameters.
ssl_certificate = /etc/certs/server.pem
# Key can be included on certificate file, so this is optional
ssl_certificate_key = /etc/certs/key.pem
# ssl_ciphers and ssl_dhparam are optional.
ssl_ciphers = ECDHE-RSA-AES256-GCM-SHA512:DHE-RSA-AES256-GCM-SHA512:ECDHE-RSA-AES256-GCM-SHA384:DHE-RSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-SHA384
@ -40,6 +47,11 @@ ssl_dhparam = /etc/certs/dhparam.pem
# https://www.example.com:14333/uds/rest/tunnel/ticket
uds_server = http://172.27.0.1:8000/uds/rest/tunnel/ticket
uds_token = eBCeFxTBw1IKXCqq-RlncshwWIfrrqxc8y5nehqiqMtRztwD
# Defaults to 10 seconds
# uds_timeout = 10
# If verify ssl certificate on uds server. Defaults to true
# uds_verify_ssl = true
# Secret to get access to admin commands (Currently only stats commands). No default for this.
# Admin commands and only allowed from "allow" ips
@ -50,3 +62,7 @@ secret = MySecret
# Only use IPs, no networks allowed
# defaults to localhost (change if listen address is different from 0.0.0.0)
allow = 127.0.0.1
# If use uvloop as event loop. Defaults to true
# use_uvloop = true

View File

@ -39,6 +39,8 @@ import ssl
import socket
import logging
from concurrent.futures import ThreadPoolExecutor
# event for stop notification
import threading
import typing
try:
@ -59,16 +61,15 @@ if typing.TYPE_CHECKING:
from multiprocessing.connection import Connection
from multiprocessing.managers import Namespace
BACKLOG = 1024
logger = logging.getLogger(__name__)
do_stop = False
running: threading.Event = threading.Event()
def stop_signal(signum: int, frame: typing.Any) -> None:
global do_stop
do_stop = True
global running
running.clear()
logger.debug('SIGNAL %s, frame: %s', signum, frame)
@ -76,26 +77,26 @@ def setup_log(cfg: config.ConfigurationType) -> None:
from logging.handlers import RotatingFileHandler
# Update logging if needed
if cfg.log_file:
if cfg.logfile:
fileh = RotatingFileHandler(
filename=cfg.log_file,
filename=cfg.logfile,
mode='a',
maxBytes=cfg.log_size,
backupCount=cfg.log_number,
maxBytes=cfg.logsize,
backupCount=cfg.lognumber,
)
formatter = logging.Formatter(consts.LOGFORMAT)
fileh.setFormatter(formatter)
log = logging.getLogger()
log.setLevel(cfg.log_level)
log.setLevel(cfg.loglevel)
# for hdlr in log.handlers[:]:
# log.removeHandler(hdlr)
log.addHandler(fileh)
else:
# Setup basic logging
log = logging.getLogger()
log.setLevel(cfg.log_level)
log.setLevel(cfg.loglevel)
handler = logging.StreamHandler(sys.stderr)
handler.setLevel(cfg.log_level)
handler.setLevel(cfg.loglevel)
formatter = logging.Formatter(
'%(levelname)s - %(message)s'
) # Basic log format, nice for syslog
@ -107,10 +108,7 @@ async def tunnel_proc_async(
pipe: 'Connection', cfg: config.ConfigurationType, ns: 'Namespace'
) -> None:
try:
loop = asyncio.get_running_loop()
except RuntimeError: # older python versions
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
tasks: typing.List[asyncio.Task] = []
@ -123,8 +121,13 @@ async def tunnel_proc_async(
] = pipe.recv()
if msg:
return msg
except EOFError:
logger.debug('Parent process closed connection')
pipe.close()
return None, None
except Exception:
logger.exception('Receiving data from parent process')
pipe.close()
return None, None
async def run_server() -> None:
@ -133,7 +136,15 @@ async def tunnel_proc_async(
# Generate SSL context
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
context.load_cert_chain(cfg.ssl_certificate, cfg.ssl_certificate_key)
args: typing.Dict[str, typing.Any] = {
'certfile': cfg.ssl_certificate,
}
if cfg.ssl_certificate_key:
args['keyfile'] = cfg.ssl_certificate_key
if cfg.ssl_password:
args['password'] = cfg.ssl_password
context.load_cert_chain(**args)
if cfg.ssl_ciphers:
context.set_ciphers(cfg.ssl_ciphers)
@ -141,29 +152,44 @@ async def tunnel_proc_async(
if cfg.ssl_dhparam:
context.load_dh_params(cfg.ssl_dhparam)
while True:
address: typing.Optional[typing.Tuple[str, int]] = ('', 0)
try:
(sock, address) = await loop.run_in_executor(None, get_socket)
if not sock:
break # No more sockets, exit
logger.debug(f'CONNECTION from {address!r} (pid: {os.getpid()})')
tasks.append(asyncio.create_task(tunneler(sock, context)))
except Exception:
logger.error('NEGOTIATION ERROR from %s', address[0] if address else 'unknown')
try:
while True:
address: typing.Optional[typing.Tuple[str, int]] = ('', 0)
try:
(sock, address) = await loop.run_in_executor(None, get_socket)
if not sock:
break # No more sockets, exit
logger.debug(f'CONNECTION from {address!r} (pid: {os.getpid()})')
tasks.append(asyncio.create_task(tunneler(sock, context)))
except asyncio.CancelledError:
raise
except Exception:
logger.error('NEGOTIATION ERROR from %s', address[0] if address else 'unknown')
except asyncio.CancelledError:
pass # Stop
# create task for server
tasks.append(asyncio.create_task(run_server()))
while tasks:
to_wait = tasks[:] # Get a copy of the list, and clean the original
# Wait for tasks to finish
done, _ = await asyncio.wait(to_wait, return_when=asyncio.FIRST_COMPLETED)
# Remove finished tasks
for task in done:
tasks.remove(task)
if task.exception():
logger.exception('TUNNEL ERROR')
try:
while tasks and running.is_set():
to_wait = tasks[:] # Get a copy of the list, and clean the original
# Wait for tasks to finish
done, _ = await asyncio.wait(to_wait, return_when=asyncio.FIRST_COMPLETED, timeout=2)
# Remove finished tasks
for task in done:
tasks.remove(task)
if task.exception():
logger.exception('TUNNEL ERROR')
except asyncio.CancelledError:
running.clear() # ensure we stop
# If any task is still running, cancel it
for task in tasks:
task.cancel()
# Wait for all tasks to finish
await asyncio.gather(*tasks, return_exceptions=True)
logger.info('PROCESS %s stopped', os.getpid())
@ -205,11 +231,11 @@ def tunnel_main(args: 'argparse.Namespace') -> None:
# logger.warning('socket.REUSEPORT not available')
try:
sock.bind((cfg.listen_address, cfg.listen_port))
sock.listen(BACKLOG)
sock.listen(consts.BACKLOG)
# If running as root, and requested drop privileges after port bind
if os.getuid() == 0 and cfg.user:
logger.debug('Changing to user %s', cfg.user)
logger.debug('Changing to user %s', cfg.user)
pwu = pwd.getpwnam(cfg.user)
# os.setgid(pwu.pw_gid)
os.setuid(pwu.pw_uid)
@ -233,16 +259,22 @@ def tunnel_main(args: 'argparse.Namespace') -> None:
return
# Setup signal handlers
signal.signal(signal.SIGINT, stop_signal)
signal.signal(signal.SIGTERM, stop_signal)
try:
signal.signal(signal.SIGINT, stop_signal)
signal.signal(signal.SIGTERM, stop_signal)
except Exception as e:
# Signal not available on threads, and we use threads on tests,
# so we will ignore this because on tests signals are not important
logger.warning('Signal not available: %s', e)
stats_collector = stats.GlobalStats()
prcs = processes.Processes(tunnel_proc_async, cfg, stats_collector.ns)
running.set() # Signal we are running
with ThreadPoolExecutor(max_workers=256) as executor:
try:
while not do_stop:
while running.is_set():
try:
client, addr = sock.accept()
logger.info('CONNECTION from %s', addr)