diff --git a/tunnel-server/src/uds_tunnel/tunnel.py b/tunnel-server/src/uds_tunnel/tunnel.py index cbf5b1b10..7ebc74d24 100644 --- a/tunnel-server/src/uds_tunnel/tunnel.py +++ b/tunnel-server/src/uds_tunnel/tunnel.py @@ -80,19 +80,25 @@ class TunnelProtocol(asyncio.Protocol): self.source = ('', 0) self.destination = ('', 0) + # If other_side is given, we are the client part (that is, the tunnel from us to remote machine) + # In this case, only do_proxy is used if other_side: self.other_side = other_side self.stats_manager = other_side.stats_manager self.counter = self.stats_manager.as_recv_counter() self.runner = self.do_proxy - else: + else: # We are the server part, that is the tunnel from client machine to us self.other_side = self self.stats_manager = stats.Stats(owner.ns) self.counter = self.stats_manager.as_sent_counter() + # We start processing command + # After command, we can process stats or do_proxy, that is the "normal" operation self.runner = self.do_command - # Set starting timeout task, se we dont get hunged on connections without data + # Set starting timeout task, se we dont get hunged on connections without data (or insufficient data) self.set_timeout(self.owner.cfg.command_timeout) + def is_server_side(self) -> bool: + return self.other_side is self def process_open(self) -> None: # Open Command has the ticket behind it @@ -261,6 +267,7 @@ class TunnelProtocol(asyncio.Protocol): def do_proxy(self, data: bytes) -> None: self.counter.add(len(data)) logger.debug('Processing proxy: %s', len(data)) + # do_proxy will only be called if other_side is set to the other side of the tunnel self.other_side.transport.write(data) # inherited from asyncio.Protocol @@ -280,6 +287,15 @@ class TunnelProtocol(asyncio.Protocol): def notify_end(self): if self.notify_ticket: + logger.info( + 'TERMINATED %s to %s, s:%s, r:%s, t:%s', + self.pretty_source(), + self.pretty_destination(), + self.stats_manager.sent, + self.stats_manager.recv, + int(self.stats_manager.end - self.stats_manager.start), + ) + # Notify end to uds, using a task becase we are not an async function asyncio.get_event_loop().create_task( TunnelProtocol.notify_end_to_uds( self.owner.cfg, self.notify_ticket, self.stats_manager @@ -290,10 +306,12 @@ class TunnelProtocol(asyncio.Protocol): self.owner.finished.set() def connection_lost(self, exc: typing.Optional[Exception]) -> None: - # Ensure close other side if any - if self.other_side is not self: + # Ensure close other side if not server_side + try: self.other_side.transport.close() - else: + except Exception: + pass + if self.other_side is self: self.stats_manager.close() self.notify_end() @@ -314,16 +332,7 @@ class TunnelProtocol(asyncio.Protocol): def close_connection(self): self.transport.close() # If destination is not set, it's a command processing (i.e. TEST or STAT) - if self.destination[0] != '': - logger.info( - 'TERMINATED %s to %s, s:%s, r:%s, t:%s', - self.pretty_source(), - self.pretty_destination(), - self.stats_manager.sent, - self.stats_manager.recv, - int(self.stats_manager.end - self.stats_manager.start), - ) - # Notify end to uds + if self.destination[0]: self.notify_end() else: logger.info('TERMINATED %s', self.pretty_source()) diff --git a/tunnel-server/src/udstunnel.py b/tunnel-server/src/udstunnel.py index 1704fa8fd..7550b597e 100755 --- a/tunnel-server/src/udstunnel.py +++ b/tunnel-server/src/udstunnel.py @@ -348,7 +348,7 @@ def main() -> None: if args.tunnel: tunnel_main(args) - if args.detailed_stats: + elif args.detailed_stats: asyncio.run(stats.getServerStats(True)) elif args.stats: asyncio.run(stats.getServerStats(False))