mirror of
https://github.com/dkmstr/openuds.git
synced 2025-03-20 06:50:23 +03:00
backported final fixes
This commit is contained in:
parent
dcedb268dd
commit
3e947e1d82
@ -66,11 +66,20 @@ class ConfigurationType(typing.NamedTuple):
|
||||
uds_timeout: int
|
||||
uds_verify_ssl: bool
|
||||
|
||||
command_timeout: int
|
||||
|
||||
secret: str
|
||||
allow: typing.Set[str]
|
||||
|
||||
use_uvloop: bool
|
||||
|
||||
def __str__(self) -> str:
|
||||
return 'Configuration: \n' + '\n'.join(
|
||||
f'{k}={v}'
|
||||
for k, v in self._asdict().items()
|
||||
)
|
||||
|
||||
|
||||
|
||||
def read_config_file(
|
||||
cfg_file: typing.Optional[typing.Union[typing.TextIO, str]] = None
|
||||
@ -131,6 +140,7 @@ def read(
|
||||
uds_token=uds.get('uds_token', 'unauthorized'),
|
||||
uds_timeout=int(uds.get('uds_timeout', '10')),
|
||||
uds_verify_ssl=uds.get('uds_verify_ssl', 'true').lower() == 'true',
|
||||
command_timeout=int(uds.get('command_timeout', '3')),
|
||||
secret=secret,
|
||||
allow=set(uds.get('allow', '127.0.0.1').split(',')),
|
||||
use_uvloop=uds.get('use_uvloop', 'true').lower() == 'true',
|
||||
|
@ -69,8 +69,5 @@ RESPONSE_FORBIDDEN: typing.Final[bytes] = b'FORBIDDEN'
|
||||
|
||||
RESPONSE_OK: typing.Final[bytes] = b'OK'
|
||||
|
||||
# Timeout for command
|
||||
TIMEOUT_COMMAND: typing.Final[int] = 3
|
||||
|
||||
# Backlog for listen socket
|
||||
BACKLOG = 1024
|
||||
|
@ -73,6 +73,13 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
) -> None:
|
||||
# If no other side is given, we are the server part
|
||||
super().__init__()
|
||||
# transport is undefined until connection_made is called
|
||||
self.cmd = b''
|
||||
self.notify_ticket = b''
|
||||
self.owner = owner
|
||||
self.source = ('', 0)
|
||||
self.destination = ('', 0)
|
||||
|
||||
if other_side:
|
||||
self.other_side = other_side
|
||||
self.stats_manager = other_side.stats_manager
|
||||
@ -84,21 +91,15 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
self.counter = self.stats_manager.as_sent_counter()
|
||||
self.runner = self.do_command
|
||||
# Set starting timeout task, se we dont get hunged on connections without data
|
||||
self.set_timeout(consts.TIMEOUT_COMMAND)
|
||||
self.set_timeout(self.owner.cfg.command_timeout)
|
||||
|
||||
# transport is undefined until connection_made is called
|
||||
self.cmd = b''
|
||||
self.notify_ticket = b''
|
||||
self.owner = owner
|
||||
self.source = ('', 0)
|
||||
self.destination = ('', 0)
|
||||
|
||||
def process_open(self) -> None:
|
||||
# Open Command has the ticket behind it
|
||||
|
||||
if len(self.cmd) < consts.TICKET_LENGTH + consts.COMMAND_LENGTH:
|
||||
# Reactivate timeout, will be deactivated on do_command
|
||||
self.set_timeout(consts.TIMEOUT_COMMAND)
|
||||
self.set_timeout(self.owner.cfg.command_timeout)
|
||||
return # Wait for more data to complete OPEN command
|
||||
|
||||
# Ticket received, now process it with UDS
|
||||
@ -253,7 +254,7 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
self.close_connection()
|
||||
return
|
||||
else:
|
||||
self.set_timeout(consts.TIMEOUT_COMMAND)
|
||||
self.set_timeout(self.owner.cfg.command_timeout)
|
||||
|
||||
# if not enough data to process command, wait for more
|
||||
|
||||
@ -289,7 +290,6 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
self.owner.finished.set()
|
||||
|
||||
def connection_lost(self, exc: typing.Optional[Exception]) -> None:
|
||||
logger.debug('Connection closed : %s', exc)
|
||||
# Ensure close other side if any
|
||||
if self.other_side is not self:
|
||||
self.other_side.transport.close()
|
||||
|
@ -53,6 +53,10 @@ uds_token = eBCeFxTBw1IKXCqq-RlncshwWIfrrqxc8y5nehqiqMtRztwD
|
||||
# If verify ssl certificate on uds server. Defaults to true
|
||||
# uds_verify_ssl = true
|
||||
|
||||
# Command timeout. Command reception on tunnel will timeout after this time (in seconds)
|
||||
# defaults to 3 seconds
|
||||
# command_timeout = 3
|
||||
|
||||
# Secret to get access to admin commands (Currently only stats commands). No default for this.
|
||||
# Admin commands and only allowed from "allow" ips
|
||||
# So, in order to allow this commands, ensure listen address allows connections from localhost
|
||||
|
@ -102,6 +102,10 @@ def setup_log(cfg: config.ConfigurationType) -> None:
|
||||
handler.setFormatter(formatter)
|
||||
log.addHandler(handler)
|
||||
|
||||
# If debug, print config
|
||||
if cfg.loglevel.lower() == 'debug':
|
||||
logger.debug('Configuration: %s', cfg)
|
||||
|
||||
|
||||
async def tunnel_proc_async(
|
||||
pipe: 'Connection', cfg: config.ConfigurationType, ns: 'Namespace'
|
||||
@ -111,6 +115,10 @@ async def tunnel_proc_async(
|
||||
|
||||
tasks: typing.List[asyncio.Task] = []
|
||||
|
||||
def add_autoremovable_task(task: asyncio.Task) -> None:
|
||||
tasks.append(task)
|
||||
task.add_done_callback(tasks.remove)
|
||||
|
||||
def get_socket() -> typing.Tuple[typing.Optional[socket.socket], typing.Optional[typing.Tuple[str, int]]]:
|
||||
try:
|
||||
while True:
|
||||
@ -157,7 +165,7 @@ async def tunnel_proc_async(
|
||||
break # No more sockets, exit
|
||||
logger.debug(f'CONNECTION from {address!r} (pid: {os.getpid()})')
|
||||
# Due to proxy contains an "event" to stop, we need to create a new one for each connection
|
||||
tasks.append(asyncio.create_task(proxy.Proxy(cfg, ns)(sock, context)))
|
||||
add_autoremovable_task(asyncio.create_task(proxy.Proxy(cfg, ns)(sock, context)))
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
@ -166,23 +174,20 @@ async def tunnel_proc_async(
|
||||
pass # Stop
|
||||
|
||||
# create task for server
|
||||
tasks.append(asyncio.create_task(run_server()))
|
||||
|
||||
add_autoremovable_task(asyncio.create_task(run_server()))
|
||||
|
||||
try:
|
||||
while tasks and not do_stop.is_set():
|
||||
to_wait = tasks[:] # Get a copy of the list
|
||||
# Wait for "to_wait" tasks to finish, stop every 2 seconds to check if we need to stop
|
||||
done, _ = await asyncio.wait(to_wait, return_when=asyncio.FIRST_COMPLETED, timeout=2)
|
||||
# Remove finished tasks
|
||||
for task in done:
|
||||
tasks.remove(task)
|
||||
if task.exception():
|
||||
logger.exception('TUNNEL ERROR')
|
||||
except asyncio.CancelledError:
|
||||
logger.info('Task cancelled')
|
||||
do_stop.set() # ensure we stop
|
||||
|
||||
logger.debug('Out of loop, stopping tasks: %s, running: %s', tasks, do_stop.is_set())
|
||||
|
||||
# If any task is still running, cancel it
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
|
Loading…
x
Reference in New Issue
Block a user