diff --git a/tunnel-server/src/uds_tunnel/proxy.py b/tunnel-server/src/uds_tunnel/proxy.py index 1baee4450..93abf28bd 100644 --- a/tunnel-server/src/uds_tunnel/proxy.py +++ b/tunnel-server/src/uds_tunnel/proxy.py @@ -65,7 +65,9 @@ class Proxy: addr = source.getpeername() except Exception: addr = 'Unknown' - logger.exception('Proxy error from %s: %s (%s--%s)', addr, e, source, context) + logger.exception( + 'Proxy error from %s: %s (%s--%s)', addr, e, source, context + ) async def proxy(self, source: socket.socket, context: 'ssl.SSLContext') -> None: loop = asyncio.get_running_loop() @@ -74,18 +76,17 @@ class Proxy: # Upgrade connection to SSL, and use asyncio to handle the rest try: - def factory() -> tunnel.TunnelProtocol: - return tunnel.TunnelProtocol(self) # (connect accepted loop not present on AbastractEventLoop definition < 3.10), that's why we use ignore await loop.connect_accepted_socket( # type: ignore - factory, source, ssl=context + lambda: tunnel.TunnelProtocol(self), source, ssl=context ) # Wait for connection to be closed await self.finished.wait() - - + except asyncio.CancelledError: pass # Return on cancel + logger.debug('Proxy finished') + return diff --git a/tunnel-server/src/uds_tunnel/tunnel.py b/tunnel-server/src/uds_tunnel/tunnel.py index 55cb05019..e4cad577a 100644 --- a/tunnel-server/src/uds_tunnel/tunnel.py +++ b/tunnel-server/src/uds_tunnel/tunnel.py @@ -68,6 +68,8 @@ class TunnelProtocol(asyncio.Protocol): # If there is a timeout task running timeout_task: typing.Optional[asyncio.Task] = None + is_server_side: bool + def __init__( self, owner: 'proxy.Proxy', other_side: typing.Optional['TunnelProtocol'] = None ) -> None: @@ -84,11 +86,13 @@ class TunnelProtocol(asyncio.Protocol): # In this case, only do_proxy is used if other_side: self.other_side = other_side + self.is_server_side = False self.stats_manager = other_side.stats_manager self.counter = self.stats_manager.as_recv_counter() self.runner = self.do_proxy else: # We are the server part, that is the tunnel from client machine to us self.other_side = self + self.is_server_side = True self.stats_manager = stats.Stats(owner.ns) self.counter = self.stats_manager.as_sent_counter() # We start processing command @@ -97,9 +101,6 @@ class TunnelProtocol(asyncio.Protocol): # Set starting timeout task, se we dont get hunged on connections without data (or insufficient data) self.set_timeout(self.owner.cfg.command_timeout) - def is_server_side(self) -> bool: - return self.other_side is self - def process_open(self) -> None: # Open Command has the ticket behind it @@ -173,35 +174,30 @@ class TunnelProtocol(asyncio.Protocol): if len(self.cmd) < consts.PASSWORD_LENGTH + consts.COMMAND_LENGTH: return - try: - logger.info('COMMAND: %s', self.cmd[: consts.COMMAND_LENGTH].decode()) + logger.info('COMMAND: %s', self.cmd[: consts.COMMAND_LENGTH].decode()) - # Check valid source ip - if self.transport.get_extra_info('peername')[0] not in self.owner.cfg.allow: - # Invalid source - self.transport.write(consts.RESPONSE_FORBIDDEN) - return + # Check valid source ip + if self.transport.get_extra_info('peername')[0] not in self.owner.cfg.allow: + # Invalid source + self.transport.write(consts.RESPONSE_FORBIDDEN) + return - # Check password, max length is consts.PASSWORD_LENGTH - passwd = self.cmd[consts.COMMAND_LENGTH : consts.PASSWORD_LENGTH + consts.COMMAND_LENGTH] + # Check password, max length is consts.PASSWORD_LENGTH + passwd = self.cmd[consts.COMMAND_LENGTH : consts.PASSWORD_LENGTH + consts.COMMAND_LENGTH] - # Clean up the command, only keep base part - self.cmd = self.cmd[:4] + # Clean up the command, only keep base part + self.cmd = self.cmd[:4] - if passwd.decode(errors='ignore') != self.owner.cfg.secret: - # Invalid password - self.transport.write(consts.RESPONSE_FORBIDDEN) - return + if passwd.decode(errors='ignore') != self.owner.cfg.secret: + # Invalid password + self.transport.write(consts.RESPONSE_FORBIDDEN) + return - data = stats.GlobalStats.get_stats(self.owner.ns) + data = stats.GlobalStats.get_stats(self.owner.ns) - for v in data: - logger.debug('SENDING %s', v) - self.transport.write(v.encode() + b'\n') - - logger.info('TERMINATED %s', self.pretty_source()) - finally: - self.close_connection() + for v in data: + logger.debug('SENDING %s', v) + self.transport.write(v.encode() + b'\n') async def timeout(self, wait: float) -> None: """Timeout can only occur while waiting for a command (or OPEN command ticket).""" @@ -235,9 +231,10 @@ class TunnelProtocol(asyncio.Protocol): if self.cmd == b'': logger.info('CONNECT FROM %s', self.pretty_source()) + # Ensure we don't get a timeout self.clean_timeout() self.cmd += data - # Ensure we don't get a timeout + if len(self.cmd) >= consts.COMMAND_LENGTH: command = self.cmd[: consts.COMMAND_LENGTH] try: @@ -250,7 +247,11 @@ class TunnelProtocol(asyncio.Protocol): return elif command in (consts.COMMAND_STAT, consts.COMMAND_INFO): # This is an stats requests - self.process_stats(full=command == consts.COMMAND_STAT) + try: + self.process_stats(full=command == consts.COMMAND_STAT) + except Exception as e: + logger.error('ERROR processing stats: %s', e.args[0] if e.args else e) + self.close_connection() return else: raise Exception('Invalid command') @@ -273,7 +274,6 @@ class TunnelProtocol(asyncio.Protocol): def connection_made(self, transport: 'asyncio.transports.BaseTransport') -> None: logger.debug('Connection made: %s', transport.get_extra_info('peername')) - self.main = True # This is the main connection # We know for sure that the transport is a Transport. self.transport = typing.cast('asyncio.transports.Transport', transport) @@ -281,7 +281,6 @@ class TunnelProtocol(asyncio.Protocol): self.source = self.transport.get_extra_info('peername') def data_received(self, data: bytes): - logger.debug('Data received: %s', len(data)) self.runner(data) # send data to current runner (command or proxy) def notify_end(self): @@ -301,12 +300,13 @@ class TunnelProtocol(asyncio.Protocol): ) ) self.notify_ticket = b'' # Clean up so no more notifications - elif self.other_side is self: - # Simple log + + if self.other_side is self: # no other side, simple connection log logger.info('TERMINATED %s', self.pretty_source()) - # In any case, ensure finished is set - self.owner.finished.set() + if self.is_server_side: + self.stats_manager.close() + self.owner.finished.set() def connection_lost(self, exc: typing.Optional[Exception]) -> None: # Ensure close other side if not server_side @@ -314,8 +314,7 @@ class TunnelProtocol(asyncio.Protocol): self.other_side.transport.close() except Exception: pass - if self.other_side is self: - self.stats_manager.close() + self.notify_end() # helpers @@ -337,8 +336,6 @@ class TunnelProtocol(asyncio.Protocol): self.transport.close() except Exception: pass # Ignore errors - # If destination is not set, it's a command processing (i.e. TEST or STAT) - self.notify_end() @staticmethod async def _read_from_uds( diff --git a/tunnel-server/src/udstunnel.py b/tunnel-server/src/udstunnel.py index 7550b597e..2437a53b1 100755 --- a/tunnel-server/src/udstunnel.py +++ b/tunnel-server/src/udstunnel.py @@ -206,11 +206,11 @@ def process_connection( data = client.recv(len(consts.HANDSHAKE_V1)) if data != consts.HANDSHAKE_V1: - raise Exception('Invalid data: {} ({})'.format( addr, data.hex())) # Invalid handshake + raise Exception('Invalid data from {}: {}'.format(addr[0], data.hex())) # Invalid handshake conn.send((client, addr)) del client # Ensure socket is controlled on child process except Exception as e: - logger.error('HANDSHAKE invalid (%s)', e) + logger.error('HANDSHAKE invalid from %s: %s', addr[0], e) # Close Source and continue client.close() @@ -280,7 +280,7 @@ def tunnel_main(args: 'argparse.Namespace') -> None: while not do_stop.is_set(): try: client, addr = sock.accept() - logger.info('CONNECTION from %s', addr) + # logger.info('CONNECTION from %s', addr) # Check if we have reached the max number of connections # First part is checked on a thread, if HANDSHAKE is valid diff --git a/tunnel-server/test/test_tunnel.py b/tunnel-server/test/test_tunnel.py index fe494ff13..4b8186459 100644 --- a/tunnel-server/test/test_tunnel.py +++ b/tunnel-server/test/test_tunnel.py @@ -77,18 +77,12 @@ class TestTunnel(IsolatedAsyncioTestCase): readed = await reader.read(1024) # Logger should have been called once with error logger_mock.error.assert_called_once() - # last (first printed) info should have been connection info - self.assertIn( - 'TERMINATED', logger_mock.info.call_args_list[-1][0][0] - ) if len(bad_cmd) < 4: # Response shout have been timeout self.assertEqual(readed, consts.RESPONSE_ERROR_TIMEOUT) # And logger should have been called with timeout self.assertIn('TIMEOUT', logger_mock.error.call_args[0][0]) - # Logger info with connection info - logger_mock.info.assert_called_once() else: # Response shout have been command error self.assertEqual(readed, consts.RESPONSE_ERROR_COMMAND) diff --git a/tunnel-server/test/test_udstunnel.py b/tunnel-server/test/test_udstunnel.py index d5ef0ce15..9c6ca0b85 100644 --- a/tunnel-server/test/test_udstunnel.py +++ b/tunnel-server/test/test_udstunnel.py @@ -36,7 +36,7 @@ from unittest import IsolatedAsyncioTestCase, mock from uds_tunnel import consts -from .utils import tuntools, tools, conf +from .utils import tuntools, tools logger = logging.getLogger(__name__)