1
0
mirror of https://github.com/dkmstr/openuds.git synced 2025-03-20 06:50:23 +03:00

some minor improvements

This commit is contained in:
Adolfo Gómez García 2023-01-05 18:16:30 +01:00
parent 99f52844ac
commit beac5caa09
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
5 changed files with 46 additions and 54 deletions

View File

@ -65,7 +65,9 @@ class Proxy:
addr = source.getpeername()
except Exception:
addr = 'Unknown'
logger.exception('Proxy error from %s: %s (%s--%s)', addr, e, source, context)
logger.exception(
'Proxy error from %s: %s (%s--%s)', addr, e, source, context
)
async def proxy(self, source: socket.socket, context: 'ssl.SSLContext') -> None:
loop = asyncio.get_running_loop()
@ -74,18 +76,17 @@ class Proxy:
# Upgrade connection to SSL, and use asyncio to handle the rest
try:
def factory() -> tunnel.TunnelProtocol:
return tunnel.TunnelProtocol(self)
# (connect accepted loop not present on AbastractEventLoop definition < 3.10), that's why we use ignore
await loop.connect_accepted_socket( # type: ignore
factory, source, ssl=context
lambda: tunnel.TunnelProtocol(self), source, ssl=context
)
# Wait for connection to be closed
await self.finished.wait()
except asyncio.CancelledError:
pass # Return on cancel
logger.debug('Proxy finished')
return

View File

@ -68,6 +68,8 @@ class TunnelProtocol(asyncio.Protocol):
# If there is a timeout task running
timeout_task: typing.Optional[asyncio.Task] = None
is_server_side: bool
def __init__(
self, owner: 'proxy.Proxy', other_side: typing.Optional['TunnelProtocol'] = None
) -> None:
@ -84,11 +86,13 @@ class TunnelProtocol(asyncio.Protocol):
# In this case, only do_proxy is used
if other_side:
self.other_side = other_side
self.is_server_side = False
self.stats_manager = other_side.stats_manager
self.counter = self.stats_manager.as_recv_counter()
self.runner = self.do_proxy
else: # We are the server part, that is the tunnel from client machine to us
self.other_side = self
self.is_server_side = True
self.stats_manager = stats.Stats(owner.ns)
self.counter = self.stats_manager.as_sent_counter()
# We start processing command
@ -97,9 +101,6 @@ class TunnelProtocol(asyncio.Protocol):
# Set starting timeout task, se we dont get hunged on connections without data (or insufficient data)
self.set_timeout(self.owner.cfg.command_timeout)
def is_server_side(self) -> bool:
return self.other_side is self
def process_open(self) -> None:
# Open Command has the ticket behind it
@ -173,35 +174,30 @@ class TunnelProtocol(asyncio.Protocol):
if len(self.cmd) < consts.PASSWORD_LENGTH + consts.COMMAND_LENGTH:
return
try:
logger.info('COMMAND: %s', self.cmd[: consts.COMMAND_LENGTH].decode())
logger.info('COMMAND: %s', self.cmd[: consts.COMMAND_LENGTH].decode())
# Check valid source ip
if self.transport.get_extra_info('peername')[0] not in self.owner.cfg.allow:
# Invalid source
self.transport.write(consts.RESPONSE_FORBIDDEN)
return
# Check valid source ip
if self.transport.get_extra_info('peername')[0] not in self.owner.cfg.allow:
# Invalid source
self.transport.write(consts.RESPONSE_FORBIDDEN)
return
# Check password, max length is consts.PASSWORD_LENGTH
passwd = self.cmd[consts.COMMAND_LENGTH : consts.PASSWORD_LENGTH + consts.COMMAND_LENGTH]
# Check password, max length is consts.PASSWORD_LENGTH
passwd = self.cmd[consts.COMMAND_LENGTH : consts.PASSWORD_LENGTH + consts.COMMAND_LENGTH]
# Clean up the command, only keep base part
self.cmd = self.cmd[:4]
# Clean up the command, only keep base part
self.cmd = self.cmd[:4]
if passwd.decode(errors='ignore') != self.owner.cfg.secret:
# Invalid password
self.transport.write(consts.RESPONSE_FORBIDDEN)
return
if passwd.decode(errors='ignore') != self.owner.cfg.secret:
# Invalid password
self.transport.write(consts.RESPONSE_FORBIDDEN)
return
data = stats.GlobalStats.get_stats(self.owner.ns)
data = stats.GlobalStats.get_stats(self.owner.ns)
for v in data:
logger.debug('SENDING %s', v)
self.transport.write(v.encode() + b'\n')
logger.info('TERMINATED %s', self.pretty_source())
finally:
self.close_connection()
for v in data:
logger.debug('SENDING %s', v)
self.transport.write(v.encode() + b'\n')
async def timeout(self, wait: float) -> None:
"""Timeout can only occur while waiting for a command (or OPEN command ticket)."""
@ -235,9 +231,10 @@ class TunnelProtocol(asyncio.Protocol):
if self.cmd == b'':
logger.info('CONNECT FROM %s', self.pretty_source())
# Ensure we don't get a timeout
self.clean_timeout()
self.cmd += data
# Ensure we don't get a timeout
if len(self.cmd) >= consts.COMMAND_LENGTH:
command = self.cmd[: consts.COMMAND_LENGTH]
try:
@ -250,7 +247,11 @@ class TunnelProtocol(asyncio.Protocol):
return
elif command in (consts.COMMAND_STAT, consts.COMMAND_INFO):
# This is an stats requests
self.process_stats(full=command == consts.COMMAND_STAT)
try:
self.process_stats(full=command == consts.COMMAND_STAT)
except Exception as e:
logger.error('ERROR processing stats: %s', e.args[0] if e.args else e)
self.close_connection()
return
else:
raise Exception('Invalid command')
@ -273,7 +274,6 @@ class TunnelProtocol(asyncio.Protocol):
def connection_made(self, transport: 'asyncio.transports.BaseTransport') -> None:
logger.debug('Connection made: %s', transport.get_extra_info('peername'))
self.main = True # This is the main connection
# We know for sure that the transport is a Transport.
self.transport = typing.cast('asyncio.transports.Transport', transport)
@ -281,7 +281,6 @@ class TunnelProtocol(asyncio.Protocol):
self.source = self.transport.get_extra_info('peername')
def data_received(self, data: bytes):
logger.debug('Data received: %s', len(data))
self.runner(data) # send data to current runner (command or proxy)
def notify_end(self):
@ -301,12 +300,13 @@ class TunnelProtocol(asyncio.Protocol):
)
)
self.notify_ticket = b'' # Clean up so no more notifications
elif self.other_side is self:
# Simple log
if self.other_side is self: # no other side, simple connection log
logger.info('TERMINATED %s', self.pretty_source())
# In any case, ensure finished is set
self.owner.finished.set()
if self.is_server_side:
self.stats_manager.close()
self.owner.finished.set()
def connection_lost(self, exc: typing.Optional[Exception]) -> None:
# Ensure close other side if not server_side
@ -314,8 +314,7 @@ class TunnelProtocol(asyncio.Protocol):
self.other_side.transport.close()
except Exception:
pass
if self.other_side is self:
self.stats_manager.close()
self.notify_end()
# helpers
@ -337,8 +336,6 @@ class TunnelProtocol(asyncio.Protocol):
self.transport.close()
except Exception:
pass # Ignore errors
# If destination is not set, it's a command processing (i.e. TEST or STAT)
self.notify_end()
@staticmethod
async def _read_from_uds(

View File

@ -206,11 +206,11 @@ def process_connection(
data = client.recv(len(consts.HANDSHAKE_V1))
if data != consts.HANDSHAKE_V1:
raise Exception('Invalid data: {} ({})'.format( addr, data.hex())) # Invalid handshake
raise Exception('Invalid data from {}: {}'.format(addr[0], data.hex())) # Invalid handshake
conn.send((client, addr))
del client # Ensure socket is controlled on child process
except Exception as e:
logger.error('HANDSHAKE invalid (%s)', e)
logger.error('HANDSHAKE invalid from %s: %s', addr[0], e)
# Close Source and continue
client.close()
@ -280,7 +280,7 @@ def tunnel_main(args: 'argparse.Namespace') -> None:
while not do_stop.is_set():
try:
client, addr = sock.accept()
logger.info('CONNECTION from %s', addr)
# logger.info('CONNECTION from %s', addr)
# Check if we have reached the max number of connections
# First part is checked on a thread, if HANDSHAKE is valid

View File

@ -77,18 +77,12 @@ class TestTunnel(IsolatedAsyncioTestCase):
readed = await reader.read(1024)
# Logger should have been called once with error
logger_mock.error.assert_called_once()
# last (first printed) info should have been connection info
self.assertIn(
'TERMINATED', logger_mock.info.call_args_list[-1][0][0]
)
if len(bad_cmd) < 4:
# Response shout have been timeout
self.assertEqual(readed, consts.RESPONSE_ERROR_TIMEOUT)
# And logger should have been called with timeout
self.assertIn('TIMEOUT', logger_mock.error.call_args[0][0])
# Logger info with connection info
logger_mock.info.assert_called_once()
else:
# Response shout have been command error
self.assertEqual(readed, consts.RESPONSE_ERROR_COMMAND)

View File

@ -36,7 +36,7 @@ from unittest import IsolatedAsyncioTestCase, mock
from uds_tunnel import consts
from .utils import tuntools, tools, conf
from .utils import tuntools, tools
logger = logging.getLogger(__name__)