mirror of
https://github.com/dkmstr/openuds.git
synced 2024-12-22 13:34:04 +03:00
upgrading and linting tunnel
This commit is contained in:
parent
ddf07eb68b
commit
4d26df9580
@ -28,6 +28,7 @@
|
||||
'''
|
||||
Author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
'''
|
||||
import re
|
||||
import typing
|
||||
|
||||
DEBUG = True
|
||||
@ -71,3 +72,6 @@ RESPONSE_OK: typing.Final[bytes] = b'OK'
|
||||
|
||||
# Backlog for listen socket
|
||||
BACKLOG = 1024
|
||||
|
||||
# Regular expression for parsing ticket
|
||||
TICKET_REGEX = re.compile(f'^[a-zA-Z0-9]{{{TICKET_LENGTH}}}$')
|
||||
|
@ -78,9 +78,7 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
# If there is a timeout task running
|
||||
timeout_task: typing.Optional[asyncio.Task] = None
|
||||
|
||||
def __init__(
|
||||
self, owner: 'proxy.Proxy'
|
||||
) -> None:
|
||||
def __init__(self, owner: 'proxy.Proxy') -> None:
|
||||
# If no other side is given, we are the server part
|
||||
super().__init__()
|
||||
# transport is undefined until connection_made is called
|
||||
@ -124,9 +122,7 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
|
||||
async def open_client() -> None:
|
||||
try:
|
||||
result = await TunnelProtocol.get_ticket_from_uds(
|
||||
self.owner.cfg, ticket, self.source
|
||||
)
|
||||
result = await TunnelProtocol.get_ticket_from_uds(self.owner.cfg, ticket, self.source)
|
||||
except Exception as e:
|
||||
logger.error('ERROR %s', e.args[0] if e.args else e)
|
||||
self.transport.write(consts.RESPONSE_ERROR_TICKET)
|
||||
@ -146,8 +142,7 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
try:
|
||||
family = (
|
||||
socket.AF_INET6
|
||||
if ':' in self.destination[0]
|
||||
or (self.owner.cfg.ipv6 and not '.' in self.destination[0])
|
||||
if ':' in self.destination[0] or (self.owner.cfg.ipv6 and '.' not in self.destination[0])
|
||||
else socket.AF_INET
|
||||
)
|
||||
(_, self.client) = await loop.create_connection(
|
||||
@ -171,7 +166,7 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
# From now, proxy connection
|
||||
self.runner = self.do_proxy
|
||||
|
||||
def process_stats(self, full: bool) -> None:
|
||||
def process_stats(self, full: bool) -> None: # pylint: disable=unused-argument
|
||||
# if pasword is not already received, wait for it
|
||||
if len(self.cmd) < consts.PASSWORD_LENGTH + consts.COMMAND_LENGTH:
|
||||
return
|
||||
@ -246,13 +241,14 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
try:
|
||||
if command == consts.COMMAND_OPEN:
|
||||
self.process_open()
|
||||
elif command == consts.COMMAND_TEST:
|
||||
return
|
||||
if command == consts.COMMAND_TEST:
|
||||
self.clean_timeout() # Stop timeout
|
||||
logger.info('COMMAND: TEST')
|
||||
self.transport.write(consts.RESPONSE_OK)
|
||||
self.close_connection()
|
||||
return
|
||||
elif command in (consts.COMMAND_STAT, consts.COMMAND_INFO):
|
||||
if command in (consts.COMMAND_STAT, consts.COMMAND_INFO):
|
||||
# This is an stats requests
|
||||
try:
|
||||
self.process_stats(full=command == consts.COMMAND_STAT)
|
||||
@ -260,7 +256,6 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
logger.error('ERROR processing stats: %s', e.args[0] if e.args else e)
|
||||
self.close_connection()
|
||||
return
|
||||
else:
|
||||
raise Exception('Invalid command')
|
||||
except Exception:
|
||||
logger.error('ERROR from %s', self.pretty_source())
|
||||
@ -298,9 +293,7 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
)
|
||||
# Notify end to uds, using a task becase we are not an async function
|
||||
asyncio.get_event_loop().create_task(
|
||||
TunnelProtocol.notify_end_to_uds(
|
||||
self.owner.cfg, self.notify_ticket, self.stats_manager
|
||||
)
|
||||
TunnelProtocol.notify_end_to_uds(self.owner.cfg, self.notify_ticket, self.stats_manager)
|
||||
)
|
||||
self.notify_ticket = b'' # Clean up so no more notifications
|
||||
else:
|
||||
@ -350,7 +343,6 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
def pretty_destination(self) -> str:
|
||||
return TunnelProtocol.pretty_address(self.destination)
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def _read_from_uds(
|
||||
cfg: config.ConfigurationType,
|
||||
@ -359,13 +351,9 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
queryParams: typing.Optional[typing.Mapping[str, str]] = None,
|
||||
) -> typing.MutableMapping[str, typing.Any]:
|
||||
try:
|
||||
url = (
|
||||
cfg.uds_server + '/' + ticket.decode() + '/' + msg + '/' + cfg.uds_token
|
||||
)
|
||||
url = cfg.uds_server + '/' + ticket.decode() + '/' + msg + '/' + cfg.uds_token
|
||||
if queryParams:
|
||||
url += '?' + '&'.join(
|
||||
[f'{key}={value}' for key, value in queryParams.items()]
|
||||
)
|
||||
url += '?' + '&'.join([f'{key}={value}' for key, value in queryParams.items()])
|
||||
# Set options
|
||||
options: typing.Dict[str, typing.Any] = {'timeout': cfg.uds_timeout}
|
||||
if cfg.uds_verify_ssl is False:
|
||||
@ -378,24 +366,15 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
raise Exception(await r.text())
|
||||
return await r.json()
|
||||
except Exception as e:
|
||||
raise Exception(f'TICKET COMMS ERROR: {ticket.decode()} {msg} {e!s}')
|
||||
raise Exception(f'TICKET COMMS ERROR: {ticket.decode()} {msg} {e!s}') from e
|
||||
|
||||
@staticmethod
|
||||
async def get_ticket_from_uds(
|
||||
cfg: config.ConfigurationType, ticket: bytes, address: typing.Tuple[str, int]
|
||||
) -> typing.MutableMapping[str, typing.Any]:
|
||||
# Sanity checks
|
||||
if len(ticket) != consts.TICKET_LENGTH:
|
||||
raise ValueError(f'TICKET INVALID (len={len(ticket)})')
|
||||
|
||||
for n, i in enumerate(ticket.decode(errors='ignore')):
|
||||
if (
|
||||
(i >= 'a' and i <= 'z')
|
||||
or (i >= '0' and i <= '9')
|
||||
or (i >= 'A' and i <= 'Z')
|
||||
):
|
||||
continue # Correctus
|
||||
raise ValueError(f'TICKET INVALID (char {i} at pos {n})')
|
||||
# Check ticket using re
|
||||
if consts.TICKET_REGEX.match(ticket.decode(errors='replace')) is None:
|
||||
raise ValueError(f'TICKET INVALID ({ticket.decode(errors="replace")})')
|
||||
|
||||
return await TunnelProtocol._read_from_uds(cfg, ticket, address[0])
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user