1
0
mirror of https://github.com/dkmstr/openuds.git synced 2025-03-20 06:50:23 +03:00

backported final fixes

This commit is contained in:
Adolfo Gómez García 2022-12-22 15:16:01 +01:00
parent dcedb268dd
commit 3e947e1d82
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
5 changed files with 36 additions and 20 deletions

View File

@ -66,11 +66,20 @@ class ConfigurationType(typing.NamedTuple):
uds_timeout: int
uds_verify_ssl: bool
command_timeout: int
secret: str
allow: typing.Set[str]
use_uvloop: bool
def __str__(self) -> str:
return 'Configuration: \n' + '\n'.join(
f'{k}={v}'
for k, v in self._asdict().items()
)
def read_config_file(
cfg_file: typing.Optional[typing.Union[typing.TextIO, str]] = None
@ -131,6 +140,7 @@ def read(
uds_token=uds.get('uds_token', 'unauthorized'),
uds_timeout=int(uds.get('uds_timeout', '10')),
uds_verify_ssl=uds.get('uds_verify_ssl', 'true').lower() == 'true',
command_timeout=int(uds.get('command_timeout', '3')),
secret=secret,
allow=set(uds.get('allow', '127.0.0.1').split(',')),
use_uvloop=uds.get('use_uvloop', 'true').lower() == 'true',

View File

@ -69,8 +69,5 @@ RESPONSE_FORBIDDEN: typing.Final[bytes] = b'FORBIDDEN'
RESPONSE_OK: typing.Final[bytes] = b'OK'
# Timeout for command
TIMEOUT_COMMAND: typing.Final[int] = 3
# Backlog for listen socket
BACKLOG = 1024

View File

@ -73,6 +73,13 @@ class TunnelProtocol(asyncio.Protocol):
) -> None:
# If no other side is given, we are the server part
super().__init__()
# transport is undefined until connection_made is called
self.cmd = b''
self.notify_ticket = b''
self.owner = owner
self.source = ('', 0)
self.destination = ('', 0)
if other_side:
self.other_side = other_side
self.stats_manager = other_side.stats_manager
@ -84,21 +91,15 @@ class TunnelProtocol(asyncio.Protocol):
self.counter = self.stats_manager.as_sent_counter()
self.runner = self.do_command
# Set starting timeout task, se we dont get hunged on connections without data
self.set_timeout(consts.TIMEOUT_COMMAND)
self.set_timeout(self.owner.cfg.command_timeout)
# transport is undefined until connection_made is called
self.cmd = b''
self.notify_ticket = b''
self.owner = owner
self.source = ('', 0)
self.destination = ('', 0)
def process_open(self) -> None:
# Open Command has the ticket behind it
if len(self.cmd) < consts.TICKET_LENGTH + consts.COMMAND_LENGTH:
# Reactivate timeout, will be deactivated on do_command
self.set_timeout(consts.TIMEOUT_COMMAND)
self.set_timeout(self.owner.cfg.command_timeout)
return # Wait for more data to complete OPEN command
# Ticket received, now process it with UDS
@ -253,7 +254,7 @@ class TunnelProtocol(asyncio.Protocol):
self.close_connection()
return
else:
self.set_timeout(consts.TIMEOUT_COMMAND)
self.set_timeout(self.owner.cfg.command_timeout)
# if not enough data to process command, wait for more
@ -289,7 +290,6 @@ class TunnelProtocol(asyncio.Protocol):
self.owner.finished.set()
def connection_lost(self, exc: typing.Optional[Exception]) -> None:
logger.debug('Connection closed : %s', exc)
# Ensure close other side if any
if self.other_side is not self:
self.other_side.transport.close()

View File

@ -53,6 +53,10 @@ uds_token = eBCeFxTBw1IKXCqq-RlncshwWIfrrqxc8y5nehqiqMtRztwD
# If verify ssl certificate on uds server. Defaults to true
# uds_verify_ssl = true
# Command timeout. Command reception on tunnel will timeout after this time (in seconds)
# defaults to 3 seconds
# command_timeout = 3
# Secret to get access to admin commands (Currently only stats commands). No default for this.
# Admin commands and only allowed from "allow" ips
# So, in order to allow this commands, ensure listen address allows connections from localhost

View File

@ -102,6 +102,10 @@ def setup_log(cfg: config.ConfigurationType) -> None:
handler.setFormatter(formatter)
log.addHandler(handler)
# If debug, print config
if cfg.loglevel.lower() == 'debug':
logger.debug('Configuration: %s', cfg)
async def tunnel_proc_async(
pipe: 'Connection', cfg: config.ConfigurationType, ns: 'Namespace'
@ -111,6 +115,10 @@ async def tunnel_proc_async(
tasks: typing.List[asyncio.Task] = []
def add_autoremovable_task(task: asyncio.Task) -> None:
tasks.append(task)
task.add_done_callback(tasks.remove)
def get_socket() -> typing.Tuple[typing.Optional[socket.socket], typing.Optional[typing.Tuple[str, int]]]:
try:
while True:
@ -157,7 +165,7 @@ async def tunnel_proc_async(
break # No more sockets, exit
logger.debug(f'CONNECTION from {address!r} (pid: {os.getpid()})')
# Due to proxy contains an "event" to stop, we need to create a new one for each connection
tasks.append(asyncio.create_task(proxy.Proxy(cfg, ns)(sock, context)))
add_autoremovable_task(asyncio.create_task(proxy.Proxy(cfg, ns)(sock, context)))
except asyncio.CancelledError:
raise
except Exception:
@ -166,23 +174,20 @@ async def tunnel_proc_async(
pass # Stop
# create task for server
tasks.append(asyncio.create_task(run_server()))
add_autoremovable_task(asyncio.create_task(run_server()))
try:
while tasks and not do_stop.is_set():
to_wait = tasks[:] # Get a copy of the list
# Wait for "to_wait" tasks to finish, stop every 2 seconds to check if we need to stop
done, _ = await asyncio.wait(to_wait, return_when=asyncio.FIRST_COMPLETED, timeout=2)
# Remove finished tasks
for task in done:
tasks.remove(task)
if task.exception():
logger.exception('TUNNEL ERROR')
except asyncio.CancelledError:
logger.info('Task cancelled')
do_stop.set() # ensure we stop
logger.debug('Out of loop, stopping tasks: %s, running: %s', tasks, do_stop.is_set())
# If any task is still running, cancel it
for task in tasks:
task.cancel()