mirror of
https://github.com/dkmstr/openuds.git
synced 2025-01-10 01:17:59 +03:00
upgrading and linting tunnel
This commit is contained in:
parent
ddf07eb68b
commit
4d26df9580
@ -28,6 +28,7 @@
|
|||||||
'''
|
'''
|
||||||
Author: Adolfo Gómez, dkmaster at dkmon dot com
|
Author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||||
'''
|
'''
|
||||||
|
import re
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
DEBUG = True
|
DEBUG = True
|
||||||
@ -71,3 +72,6 @@ RESPONSE_OK: typing.Final[bytes] = b'OK'
|
|||||||
|
|
||||||
# Backlog for listen socket
|
# Backlog for listen socket
|
||||||
BACKLOG = 1024
|
BACKLOG = 1024
|
||||||
|
|
||||||
|
# Regular expression for parsing ticket
|
||||||
|
TICKET_REGEX = re.compile(f'^[a-zA-Z0-9]{{{TICKET_LENGTH}}}$')
|
||||||
|
@ -78,9 +78,7 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
# If there is a timeout task running
|
# If there is a timeout task running
|
||||||
timeout_task: typing.Optional[asyncio.Task] = None
|
timeout_task: typing.Optional[asyncio.Task] = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, owner: 'proxy.Proxy') -> None:
|
||||||
self, owner: 'proxy.Proxy'
|
|
||||||
) -> None:
|
|
||||||
# If no other side is given, we are the server part
|
# If no other side is given, we are the server part
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# transport is undefined until connection_made is called
|
# transport is undefined until connection_made is called
|
||||||
@ -91,7 +89,7 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
self.destination = ('', 0)
|
self.destination = ('', 0)
|
||||||
self.tls_version = ''
|
self.tls_version = ''
|
||||||
self.tls_cipher = ''
|
self.tls_cipher = ''
|
||||||
|
|
||||||
# If other_side is given, we are the client part (that is, the tunnel from us to remote machine)
|
# If other_side is given, we are the client part (that is, the tunnel from us to remote machine)
|
||||||
# In this case, only do_proxy is used
|
# In this case, only do_proxy is used
|
||||||
self.client = None
|
self.client = None
|
||||||
@ -124,9 +122,7 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
|
|
||||||
async def open_client() -> None:
|
async def open_client() -> None:
|
||||||
try:
|
try:
|
||||||
result = await TunnelProtocol.get_ticket_from_uds(
|
result = await TunnelProtocol.get_ticket_from_uds(self.owner.cfg, ticket, self.source)
|
||||||
self.owner.cfg, ticket, self.source
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error('ERROR %s', e.args[0] if e.args else e)
|
logger.error('ERROR %s', e.args[0] if e.args else e)
|
||||||
self.transport.write(consts.RESPONSE_ERROR_TICKET)
|
self.transport.write(consts.RESPONSE_ERROR_TICKET)
|
||||||
@ -146,8 +142,7 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
try:
|
try:
|
||||||
family = (
|
family = (
|
||||||
socket.AF_INET6
|
socket.AF_INET6
|
||||||
if ':' in self.destination[0]
|
if ':' in self.destination[0] or (self.owner.cfg.ipv6 and '.' not in self.destination[0])
|
||||||
or (self.owner.cfg.ipv6 and not '.' in self.destination[0])
|
|
||||||
else socket.AF_INET
|
else socket.AF_INET
|
||||||
)
|
)
|
||||||
(_, self.client) = await loop.create_connection(
|
(_, self.client) = await loop.create_connection(
|
||||||
@ -161,7 +156,7 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
self.transport.resume_reading()
|
self.transport.resume_reading()
|
||||||
# send OK to client
|
# send OK to client
|
||||||
self.transport.write(b'OK')
|
self.transport.write(b'OK')
|
||||||
self.stats_manager.increment_connections() # Increment connections counters
|
self.stats_manager.increment_connections() # Increment connections counters
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error('Error opening connection: %s', e)
|
logger.error('Error opening connection: %s', e)
|
||||||
self.close_connection()
|
self.close_connection()
|
||||||
@ -171,7 +166,7 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
# From now, proxy connection
|
# From now, proxy connection
|
||||||
self.runner = self.do_proxy
|
self.runner = self.do_proxy
|
||||||
|
|
||||||
def process_stats(self, full: bool) -> None:
|
def process_stats(self, full: bool) -> None: # pylint: disable=unused-argument
|
||||||
# if pasword is not already received, wait for it
|
# if pasword is not already received, wait for it
|
||||||
if len(self.cmd) < consts.PASSWORD_LENGTH + consts.COMMAND_LENGTH:
|
if len(self.cmd) < consts.PASSWORD_LENGTH + consts.COMMAND_LENGTH:
|
||||||
return
|
return
|
||||||
@ -246,22 +241,22 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
try:
|
try:
|
||||||
if command == consts.COMMAND_OPEN:
|
if command == consts.COMMAND_OPEN:
|
||||||
self.process_open()
|
self.process_open()
|
||||||
elif command == consts.COMMAND_TEST:
|
return
|
||||||
|
if command == consts.COMMAND_TEST:
|
||||||
self.clean_timeout() # Stop timeout
|
self.clean_timeout() # Stop timeout
|
||||||
logger.info('COMMAND: TEST')
|
logger.info('COMMAND: TEST')
|
||||||
self.transport.write(consts.RESPONSE_OK)
|
self.transport.write(consts.RESPONSE_OK)
|
||||||
self.close_connection()
|
self.close_connection()
|
||||||
return
|
return
|
||||||
elif command in (consts.COMMAND_STAT, consts.COMMAND_INFO):
|
if command in (consts.COMMAND_STAT, consts.COMMAND_INFO):
|
||||||
# This is an stats requests
|
# This is an stats requests
|
||||||
try:
|
try:
|
||||||
self.process_stats(full=command == consts.COMMAND_STAT)
|
self.process_stats(full=command == consts.COMMAND_STAT)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error('ERROR processing stats: %s', e.args[0] if e.args else e)
|
logger.error('ERROR processing stats: %s', e.args[0] if e.args else e)
|
||||||
self.close_connection()
|
self.close_connection()
|
||||||
return
|
return
|
||||||
else:
|
raise Exception('Invalid command')
|
||||||
raise Exception('Invalid command')
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.error('ERROR from %s', self.pretty_source())
|
logger.error('ERROR from %s', self.pretty_source())
|
||||||
self.transport.write(consts.RESPONSE_ERROR_COMMAND)
|
self.transport.write(consts.RESPONSE_ERROR_COMMAND)
|
||||||
@ -298,9 +293,7 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
)
|
)
|
||||||
# Notify end to uds, using a task becase we are not an async function
|
# Notify end to uds, using a task becase we are not an async function
|
||||||
asyncio.get_event_loop().create_task(
|
asyncio.get_event_loop().create_task(
|
||||||
TunnelProtocol.notify_end_to_uds(
|
TunnelProtocol.notify_end_to_uds(self.owner.cfg, self.notify_ticket, self.stats_manager)
|
||||||
self.owner.cfg, self.notify_ticket, self.stats_manager
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
self.notify_ticket = b'' # Clean up so no more notifications
|
self.notify_ticket = b'' # Clean up so no more notifications
|
||||||
else:
|
else:
|
||||||
@ -350,7 +343,6 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
def pretty_destination(self) -> str:
|
def pretty_destination(self) -> str:
|
||||||
return TunnelProtocol.pretty_address(self.destination)
|
return TunnelProtocol.pretty_address(self.destination)
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _read_from_uds(
|
async def _read_from_uds(
|
||||||
cfg: config.ConfigurationType,
|
cfg: config.ConfigurationType,
|
||||||
@ -359,13 +351,9 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
queryParams: typing.Optional[typing.Mapping[str, str]] = None,
|
queryParams: typing.Optional[typing.Mapping[str, str]] = None,
|
||||||
) -> typing.MutableMapping[str, typing.Any]:
|
) -> typing.MutableMapping[str, typing.Any]:
|
||||||
try:
|
try:
|
||||||
url = (
|
url = cfg.uds_server + '/' + ticket.decode() + '/' + msg + '/' + cfg.uds_token
|
||||||
cfg.uds_server + '/' + ticket.decode() + '/' + msg + '/' + cfg.uds_token
|
|
||||||
)
|
|
||||||
if queryParams:
|
if queryParams:
|
||||||
url += '?' + '&'.join(
|
url += '?' + '&'.join([f'{key}={value}' for key, value in queryParams.items()])
|
||||||
[f'{key}={value}' for key, value in queryParams.items()]
|
|
||||||
)
|
|
||||||
# Set options
|
# Set options
|
||||||
options: typing.Dict[str, typing.Any] = {'timeout': cfg.uds_timeout}
|
options: typing.Dict[str, typing.Any] = {'timeout': cfg.uds_timeout}
|
||||||
if cfg.uds_verify_ssl is False:
|
if cfg.uds_verify_ssl is False:
|
||||||
@ -378,24 +366,15 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
raise Exception(await r.text())
|
raise Exception(await r.text())
|
||||||
return await r.json()
|
return await r.json()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(f'TICKET COMMS ERROR: {ticket.decode()} {msg} {e!s}')
|
raise Exception(f'TICKET COMMS ERROR: {ticket.decode()} {msg} {e!s}') from e
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def get_ticket_from_uds(
|
async def get_ticket_from_uds(
|
||||||
cfg: config.ConfigurationType, ticket: bytes, address: typing.Tuple[str, int]
|
cfg: config.ConfigurationType, ticket: bytes, address: typing.Tuple[str, int]
|
||||||
) -> typing.MutableMapping[str, typing.Any]:
|
) -> typing.MutableMapping[str, typing.Any]:
|
||||||
# Sanity checks
|
# Check ticket using re
|
||||||
if len(ticket) != consts.TICKET_LENGTH:
|
if consts.TICKET_REGEX.match(ticket.decode(errors='replace')) is None:
|
||||||
raise ValueError(f'TICKET INVALID (len={len(ticket)})')
|
raise ValueError(f'TICKET INVALID ({ticket.decode(errors="replace")})')
|
||||||
|
|
||||||
for n, i in enumerate(ticket.decode(errors='ignore')):
|
|
||||||
if (
|
|
||||||
(i >= 'a' and i <= 'z')
|
|
||||||
or (i >= '0' and i <= '9')
|
|
||||||
or (i >= 'A' and i <= 'Z')
|
|
||||||
):
|
|
||||||
continue # Correctus
|
|
||||||
raise ValueError(f'TICKET INVALID (char {i} at pos {n})')
|
|
||||||
|
|
||||||
return await TunnelProtocol._read_from_uds(cfg, ticket, address[0])
|
return await TunnelProtocol._read_from_uds(cfg, ticket, address[0])
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user