mirror of
https://github.com/dkmstr/openuds.git
synced 2025-08-24 09:49:52 +03:00
backported final fixes
This commit is contained in:
@ -66,11 +66,20 @@ class ConfigurationType(typing.NamedTuple):
|
|||||||
uds_timeout: int
|
uds_timeout: int
|
||||||
uds_verify_ssl: bool
|
uds_verify_ssl: bool
|
||||||
|
|
||||||
|
command_timeout: int
|
||||||
|
|
||||||
secret: str
|
secret: str
|
||||||
allow: typing.Set[str]
|
allow: typing.Set[str]
|
||||||
|
|
||||||
use_uvloop: bool
|
use_uvloop: bool
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return 'Configuration: \n' + '\n'.join(
|
||||||
|
f'{k}={v}'
|
||||||
|
for k, v in self._asdict().items()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def read_config_file(
|
def read_config_file(
|
||||||
cfg_file: typing.Optional[typing.Union[typing.TextIO, str]] = None
|
cfg_file: typing.Optional[typing.Union[typing.TextIO, str]] = None
|
||||||
@ -131,6 +140,7 @@ def read(
|
|||||||
uds_token=uds.get('uds_token', 'unauthorized'),
|
uds_token=uds.get('uds_token', 'unauthorized'),
|
||||||
uds_timeout=int(uds.get('uds_timeout', '10')),
|
uds_timeout=int(uds.get('uds_timeout', '10')),
|
||||||
uds_verify_ssl=uds.get('uds_verify_ssl', 'true').lower() == 'true',
|
uds_verify_ssl=uds.get('uds_verify_ssl', 'true').lower() == 'true',
|
||||||
|
command_timeout=int(uds.get('command_timeout', '3')),
|
||||||
secret=secret,
|
secret=secret,
|
||||||
allow=set(uds.get('allow', '127.0.0.1').split(',')),
|
allow=set(uds.get('allow', '127.0.0.1').split(',')),
|
||||||
use_uvloop=uds.get('use_uvloop', 'true').lower() == 'true',
|
use_uvloop=uds.get('use_uvloop', 'true').lower() == 'true',
|
||||||
|
@ -69,8 +69,5 @@ RESPONSE_FORBIDDEN: typing.Final[bytes] = b'FORBIDDEN'
|
|||||||
|
|
||||||
RESPONSE_OK: typing.Final[bytes] = b'OK'
|
RESPONSE_OK: typing.Final[bytes] = b'OK'
|
||||||
|
|
||||||
# Timeout for command
|
|
||||||
TIMEOUT_COMMAND: typing.Final[int] = 3
|
|
||||||
|
|
||||||
# Backlog for listen socket
|
# Backlog for listen socket
|
||||||
BACKLOG = 1024
|
BACKLOG = 1024
|
||||||
|
@ -73,6 +73,13 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
) -> None:
|
) -> None:
|
||||||
# If no other side is given, we are the server part
|
# If no other side is given, we are the server part
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
# transport is undefined until connection_made is called
|
||||||
|
self.cmd = b''
|
||||||
|
self.notify_ticket = b''
|
||||||
|
self.owner = owner
|
||||||
|
self.source = ('', 0)
|
||||||
|
self.destination = ('', 0)
|
||||||
|
|
||||||
if other_side:
|
if other_side:
|
||||||
self.other_side = other_side
|
self.other_side = other_side
|
||||||
self.stats_manager = other_side.stats_manager
|
self.stats_manager = other_side.stats_manager
|
||||||
@ -84,21 +91,15 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
self.counter = self.stats_manager.as_sent_counter()
|
self.counter = self.stats_manager.as_sent_counter()
|
||||||
self.runner = self.do_command
|
self.runner = self.do_command
|
||||||
# Set starting timeout task, se we dont get hunged on connections without data
|
# Set starting timeout task, se we dont get hunged on connections without data
|
||||||
self.set_timeout(consts.TIMEOUT_COMMAND)
|
self.set_timeout(self.owner.cfg.command_timeout)
|
||||||
|
|
||||||
# transport is undefined until connection_made is called
|
|
||||||
self.cmd = b''
|
|
||||||
self.notify_ticket = b''
|
|
||||||
self.owner = owner
|
|
||||||
self.source = ('', 0)
|
|
||||||
self.destination = ('', 0)
|
|
||||||
|
|
||||||
def process_open(self) -> None:
|
def process_open(self) -> None:
|
||||||
# Open Command has the ticket behind it
|
# Open Command has the ticket behind it
|
||||||
|
|
||||||
if len(self.cmd) < consts.TICKET_LENGTH + consts.COMMAND_LENGTH:
|
if len(self.cmd) < consts.TICKET_LENGTH + consts.COMMAND_LENGTH:
|
||||||
# Reactivate timeout, will be deactivated on do_command
|
# Reactivate timeout, will be deactivated on do_command
|
||||||
self.set_timeout(consts.TIMEOUT_COMMAND)
|
self.set_timeout(self.owner.cfg.command_timeout)
|
||||||
return # Wait for more data to complete OPEN command
|
return # Wait for more data to complete OPEN command
|
||||||
|
|
||||||
# Ticket received, now process it with UDS
|
# Ticket received, now process it with UDS
|
||||||
@ -253,7 +254,7 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
self.close_connection()
|
self.close_connection()
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
self.set_timeout(consts.TIMEOUT_COMMAND)
|
self.set_timeout(self.owner.cfg.command_timeout)
|
||||||
|
|
||||||
# if not enough data to process command, wait for more
|
# if not enough data to process command, wait for more
|
||||||
|
|
||||||
@ -289,7 +290,6 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
self.owner.finished.set()
|
self.owner.finished.set()
|
||||||
|
|
||||||
def connection_lost(self, exc: typing.Optional[Exception]) -> None:
|
def connection_lost(self, exc: typing.Optional[Exception]) -> None:
|
||||||
logger.debug('Connection closed : %s', exc)
|
|
||||||
# Ensure close other side if any
|
# Ensure close other side if any
|
||||||
if self.other_side is not self:
|
if self.other_side is not self:
|
||||||
self.other_side.transport.close()
|
self.other_side.transport.close()
|
||||||
|
@ -53,6 +53,10 @@ uds_token = eBCeFxTBw1IKXCqq-RlncshwWIfrrqxc8y5nehqiqMtRztwD
|
|||||||
# If verify ssl certificate on uds server. Defaults to true
|
# If verify ssl certificate on uds server. Defaults to true
|
||||||
# uds_verify_ssl = true
|
# uds_verify_ssl = true
|
||||||
|
|
||||||
|
# Command timeout. Command reception on tunnel will timeout after this time (in seconds)
|
||||||
|
# defaults to 3 seconds
|
||||||
|
# command_timeout = 3
|
||||||
|
|
||||||
# Secret to get access to admin commands (Currently only stats commands). No default for this.
|
# Secret to get access to admin commands (Currently only stats commands). No default for this.
|
||||||
# Admin commands and only allowed from "allow" ips
|
# Admin commands and only allowed from "allow" ips
|
||||||
# So, in order to allow this commands, ensure listen address allows connections from localhost
|
# So, in order to allow this commands, ensure listen address allows connections from localhost
|
||||||
|
@ -102,6 +102,10 @@ def setup_log(cfg: config.ConfigurationType) -> None:
|
|||||||
handler.setFormatter(formatter)
|
handler.setFormatter(formatter)
|
||||||
log.addHandler(handler)
|
log.addHandler(handler)
|
||||||
|
|
||||||
|
# If debug, print config
|
||||||
|
if cfg.loglevel.lower() == 'debug':
|
||||||
|
logger.debug('Configuration: %s', cfg)
|
||||||
|
|
||||||
|
|
||||||
async def tunnel_proc_async(
|
async def tunnel_proc_async(
|
||||||
pipe: 'Connection', cfg: config.ConfigurationType, ns: 'Namespace'
|
pipe: 'Connection', cfg: config.ConfigurationType, ns: 'Namespace'
|
||||||
@ -111,6 +115,10 @@ async def tunnel_proc_async(
|
|||||||
|
|
||||||
tasks: typing.List[asyncio.Task] = []
|
tasks: typing.List[asyncio.Task] = []
|
||||||
|
|
||||||
|
def add_autoremovable_task(task: asyncio.Task) -> None:
|
||||||
|
tasks.append(task)
|
||||||
|
task.add_done_callback(tasks.remove)
|
||||||
|
|
||||||
def get_socket() -> typing.Tuple[typing.Optional[socket.socket], typing.Optional[typing.Tuple[str, int]]]:
|
def get_socket() -> typing.Tuple[typing.Optional[socket.socket], typing.Optional[typing.Tuple[str, int]]]:
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
@ -157,7 +165,7 @@ async def tunnel_proc_async(
|
|||||||
break # No more sockets, exit
|
break # No more sockets, exit
|
||||||
logger.debug(f'CONNECTION from {address!r} (pid: {os.getpid()})')
|
logger.debug(f'CONNECTION from {address!r} (pid: {os.getpid()})')
|
||||||
# Due to proxy contains an "event" to stop, we need to create a new one for each connection
|
# Due to proxy contains an "event" to stop, we need to create a new one for each connection
|
||||||
tasks.append(asyncio.create_task(proxy.Proxy(cfg, ns)(sock, context)))
|
add_autoremovable_task(asyncio.create_task(proxy.Proxy(cfg, ns)(sock, context)))
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
raise
|
raise
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -166,23 +174,20 @@ async def tunnel_proc_async(
|
|||||||
pass # Stop
|
pass # Stop
|
||||||
|
|
||||||
# create task for server
|
# create task for server
|
||||||
tasks.append(asyncio.create_task(run_server()))
|
|
||||||
|
add_autoremovable_task(asyncio.create_task(run_server()))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while tasks and not do_stop.is_set():
|
while tasks and not do_stop.is_set():
|
||||||
to_wait = tasks[:] # Get a copy of the list
|
to_wait = tasks[:] # Get a copy of the list
|
||||||
# Wait for "to_wait" tasks to finish, stop every 2 seconds to check if we need to stop
|
# Wait for "to_wait" tasks to finish, stop every 2 seconds to check if we need to stop
|
||||||
done, _ = await asyncio.wait(to_wait, return_when=asyncio.FIRST_COMPLETED, timeout=2)
|
done, _ = await asyncio.wait(to_wait, return_when=asyncio.FIRST_COMPLETED, timeout=2)
|
||||||
# Remove finished tasks
|
|
||||||
for task in done:
|
|
||||||
tasks.remove(task)
|
|
||||||
if task.exception():
|
|
||||||
logger.exception('TUNNEL ERROR')
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.info('Task cancelled')
|
logger.info('Task cancelled')
|
||||||
do_stop.set() # ensure we stop
|
do_stop.set() # ensure we stop
|
||||||
|
|
||||||
logger.debug('Out of loop, stopping tasks: %s, running: %s', tasks, do_stop.is_set())
|
logger.debug('Out of loop, stopping tasks: %s, running: %s', tasks, do_stop.is_set())
|
||||||
|
|
||||||
# If any task is still running, cancel it
|
# If any task is still running, cancel it
|
||||||
for task in tasks:
|
for task in tasks:
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
Reference in New Issue
Block a user