mirror of
https://github.com/dkmstr/openuds.git
synced 2025-03-20 06:50:23 +03:00
some minor improvements
This commit is contained in:
parent
99f52844ac
commit
beac5caa09
@ -65,7 +65,9 @@ class Proxy:
|
||||
addr = source.getpeername()
|
||||
except Exception:
|
||||
addr = 'Unknown'
|
||||
logger.exception('Proxy error from %s: %s (%s--%s)', addr, e, source, context)
|
||||
logger.exception(
|
||||
'Proxy error from %s: %s (%s--%s)', addr, e, source, context
|
||||
)
|
||||
|
||||
async def proxy(self, source: socket.socket, context: 'ssl.SSLContext') -> None:
|
||||
loop = asyncio.get_running_loop()
|
||||
@ -74,18 +76,17 @@ class Proxy:
|
||||
|
||||
# Upgrade connection to SSL, and use asyncio to handle the rest
|
||||
try:
|
||||
def factory() -> tunnel.TunnelProtocol:
|
||||
return tunnel.TunnelProtocol(self)
|
||||
# (connect accepted loop not present on AbastractEventLoop definition < 3.10), that's why we use ignore
|
||||
await loop.connect_accepted_socket( # type: ignore
|
||||
factory, source, ssl=context
|
||||
lambda: tunnel.TunnelProtocol(self), source, ssl=context
|
||||
)
|
||||
|
||||
# Wait for connection to be closed
|
||||
await self.finished.wait()
|
||||
|
||||
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass # Return on cancel
|
||||
|
||||
logger.debug('Proxy finished')
|
||||
|
||||
return
|
||||
|
@ -68,6 +68,8 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
# If there is a timeout task running
|
||||
timeout_task: typing.Optional[asyncio.Task] = None
|
||||
|
||||
is_server_side: bool
|
||||
|
||||
def __init__(
|
||||
self, owner: 'proxy.Proxy', other_side: typing.Optional['TunnelProtocol'] = None
|
||||
) -> None:
|
||||
@ -84,11 +86,13 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
# In this case, only do_proxy is used
|
||||
if other_side:
|
||||
self.other_side = other_side
|
||||
self.is_server_side = False
|
||||
self.stats_manager = other_side.stats_manager
|
||||
self.counter = self.stats_manager.as_recv_counter()
|
||||
self.runner = self.do_proxy
|
||||
else: # We are the server part, that is the tunnel from client machine to us
|
||||
self.other_side = self
|
||||
self.is_server_side = True
|
||||
self.stats_manager = stats.Stats(owner.ns)
|
||||
self.counter = self.stats_manager.as_sent_counter()
|
||||
# We start processing command
|
||||
@ -97,9 +101,6 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
# Set starting timeout task, se we dont get hunged on connections without data (or insufficient data)
|
||||
self.set_timeout(self.owner.cfg.command_timeout)
|
||||
|
||||
def is_server_side(self) -> bool:
|
||||
return self.other_side is self
|
||||
|
||||
def process_open(self) -> None:
|
||||
# Open Command has the ticket behind it
|
||||
|
||||
@ -173,35 +174,30 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
if len(self.cmd) < consts.PASSWORD_LENGTH + consts.COMMAND_LENGTH:
|
||||
return
|
||||
|
||||
try:
|
||||
logger.info('COMMAND: %s', self.cmd[: consts.COMMAND_LENGTH].decode())
|
||||
logger.info('COMMAND: %s', self.cmd[: consts.COMMAND_LENGTH].decode())
|
||||
|
||||
# Check valid source ip
|
||||
if self.transport.get_extra_info('peername')[0] not in self.owner.cfg.allow:
|
||||
# Invalid source
|
||||
self.transport.write(consts.RESPONSE_FORBIDDEN)
|
||||
return
|
||||
# Check valid source ip
|
||||
if self.transport.get_extra_info('peername')[0] not in self.owner.cfg.allow:
|
||||
# Invalid source
|
||||
self.transport.write(consts.RESPONSE_FORBIDDEN)
|
||||
return
|
||||
|
||||
# Check password, max length is consts.PASSWORD_LENGTH
|
||||
passwd = self.cmd[consts.COMMAND_LENGTH : consts.PASSWORD_LENGTH + consts.COMMAND_LENGTH]
|
||||
# Check password, max length is consts.PASSWORD_LENGTH
|
||||
passwd = self.cmd[consts.COMMAND_LENGTH : consts.PASSWORD_LENGTH + consts.COMMAND_LENGTH]
|
||||
|
||||
# Clean up the command, only keep base part
|
||||
self.cmd = self.cmd[:4]
|
||||
# Clean up the command, only keep base part
|
||||
self.cmd = self.cmd[:4]
|
||||
|
||||
if passwd.decode(errors='ignore') != self.owner.cfg.secret:
|
||||
# Invalid password
|
||||
self.transport.write(consts.RESPONSE_FORBIDDEN)
|
||||
return
|
||||
if passwd.decode(errors='ignore') != self.owner.cfg.secret:
|
||||
# Invalid password
|
||||
self.transport.write(consts.RESPONSE_FORBIDDEN)
|
||||
return
|
||||
|
||||
data = stats.GlobalStats.get_stats(self.owner.ns)
|
||||
data = stats.GlobalStats.get_stats(self.owner.ns)
|
||||
|
||||
for v in data:
|
||||
logger.debug('SENDING %s', v)
|
||||
self.transport.write(v.encode() + b'\n')
|
||||
|
||||
logger.info('TERMINATED %s', self.pretty_source())
|
||||
finally:
|
||||
self.close_connection()
|
||||
for v in data:
|
||||
logger.debug('SENDING %s', v)
|
||||
self.transport.write(v.encode() + b'\n')
|
||||
|
||||
async def timeout(self, wait: float) -> None:
|
||||
"""Timeout can only occur while waiting for a command (or OPEN command ticket)."""
|
||||
@ -235,9 +231,10 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
if self.cmd == b'':
|
||||
logger.info('CONNECT FROM %s', self.pretty_source())
|
||||
|
||||
# Ensure we don't get a timeout
|
||||
self.clean_timeout()
|
||||
self.cmd += data
|
||||
# Ensure we don't get a timeout
|
||||
|
||||
if len(self.cmd) >= consts.COMMAND_LENGTH:
|
||||
command = self.cmd[: consts.COMMAND_LENGTH]
|
||||
try:
|
||||
@ -250,7 +247,11 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
return
|
||||
elif command in (consts.COMMAND_STAT, consts.COMMAND_INFO):
|
||||
# This is an stats requests
|
||||
self.process_stats(full=command == consts.COMMAND_STAT)
|
||||
try:
|
||||
self.process_stats(full=command == consts.COMMAND_STAT)
|
||||
except Exception as e:
|
||||
logger.error('ERROR processing stats: %s', e.args[0] if e.args else e)
|
||||
self.close_connection()
|
||||
return
|
||||
else:
|
||||
raise Exception('Invalid command')
|
||||
@ -273,7 +274,6 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
|
||||
def connection_made(self, transport: 'asyncio.transports.BaseTransport') -> None:
|
||||
logger.debug('Connection made: %s', transport.get_extra_info('peername'))
|
||||
self.main = True # This is the main connection
|
||||
|
||||
# We know for sure that the transport is a Transport.
|
||||
self.transport = typing.cast('asyncio.transports.Transport', transport)
|
||||
@ -281,7 +281,6 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
self.source = self.transport.get_extra_info('peername')
|
||||
|
||||
def data_received(self, data: bytes):
|
||||
logger.debug('Data received: %s', len(data))
|
||||
self.runner(data) # send data to current runner (command or proxy)
|
||||
|
||||
def notify_end(self):
|
||||
@ -301,12 +300,13 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
)
|
||||
)
|
||||
self.notify_ticket = b'' # Clean up so no more notifications
|
||||
elif self.other_side is self:
|
||||
# Simple log
|
||||
|
||||
if self.other_side is self: # no other side, simple connection log
|
||||
logger.info('TERMINATED %s', self.pretty_source())
|
||||
|
||||
# In any case, ensure finished is set
|
||||
self.owner.finished.set()
|
||||
if self.is_server_side:
|
||||
self.stats_manager.close()
|
||||
self.owner.finished.set()
|
||||
|
||||
def connection_lost(self, exc: typing.Optional[Exception]) -> None:
|
||||
# Ensure close other side if not server_side
|
||||
@ -314,8 +314,7 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
self.other_side.transport.close()
|
||||
except Exception:
|
||||
pass
|
||||
if self.other_side is self:
|
||||
self.stats_manager.close()
|
||||
|
||||
self.notify_end()
|
||||
|
||||
# helpers
|
||||
@ -337,8 +336,6 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
self.transport.close()
|
||||
except Exception:
|
||||
pass # Ignore errors
|
||||
# If destination is not set, it's a command processing (i.e. TEST or STAT)
|
||||
self.notify_end()
|
||||
|
||||
@staticmethod
|
||||
async def _read_from_uds(
|
||||
|
@ -206,11 +206,11 @@ def process_connection(
|
||||
data = client.recv(len(consts.HANDSHAKE_V1))
|
||||
|
||||
if data != consts.HANDSHAKE_V1:
|
||||
raise Exception('Invalid data: {} ({})'.format( addr, data.hex())) # Invalid handshake
|
||||
raise Exception('Invalid data from {}: {}'.format(addr[0], data.hex())) # Invalid handshake
|
||||
conn.send((client, addr))
|
||||
del client # Ensure socket is controlled on child process
|
||||
except Exception as e:
|
||||
logger.error('HANDSHAKE invalid (%s)', e)
|
||||
logger.error('HANDSHAKE invalid from %s: %s', addr[0], e)
|
||||
# Close Source and continue
|
||||
client.close()
|
||||
|
||||
@ -280,7 +280,7 @@ def tunnel_main(args: 'argparse.Namespace') -> None:
|
||||
while not do_stop.is_set():
|
||||
try:
|
||||
client, addr = sock.accept()
|
||||
logger.info('CONNECTION from %s', addr)
|
||||
# logger.info('CONNECTION from %s', addr)
|
||||
|
||||
# Check if we have reached the max number of connections
|
||||
# First part is checked on a thread, if HANDSHAKE is valid
|
||||
|
@ -77,18 +77,12 @@ class TestTunnel(IsolatedAsyncioTestCase):
|
||||
readed = await reader.read(1024)
|
||||
# Logger should have been called once with error
|
||||
logger_mock.error.assert_called_once()
|
||||
# last (first printed) info should have been connection info
|
||||
self.assertIn(
|
||||
'TERMINATED', logger_mock.info.call_args_list[-1][0][0]
|
||||
)
|
||||
|
||||
if len(bad_cmd) < 4:
|
||||
# Response shout have been timeout
|
||||
self.assertEqual(readed, consts.RESPONSE_ERROR_TIMEOUT)
|
||||
# And logger should have been called with timeout
|
||||
self.assertIn('TIMEOUT', logger_mock.error.call_args[0][0])
|
||||
# Logger info with connection info
|
||||
logger_mock.info.assert_called_once()
|
||||
else:
|
||||
# Response shout have been command error
|
||||
self.assertEqual(readed, consts.RESPONSE_ERROR_COMMAND)
|
||||
|
@ -36,7 +36,7 @@ from unittest import IsolatedAsyncioTestCase, mock
|
||||
|
||||
from uds_tunnel import consts
|
||||
|
||||
from .utils import tuntools, tools, conf
|
||||
from .utils import tuntools, tools
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user