1
0
mirror of https://github.com/dkmstr/openuds.git synced 2025-08-24 09:49:52 +03:00

small tunnel fix

This commit is contained in:
Adolfo Gómez García
2023-01-05 17:20:12 +01:00
parent efd132bd39
commit 99f52844ac

View File

@ -266,7 +266,6 @@ class TunnelProtocol(asyncio.Protocol):
def do_proxy(self, data: bytes) -> None: def do_proxy(self, data: bytes) -> None:
self.counter.add(len(data)) self.counter.add(len(data))
logger.debug('Processing proxy: %s', len(data))
# do_proxy will only be called if other_side is set to the other side of the tunnel # do_proxy will only be called if other_side is set to the other side of the tunnel
self.other_side.transport.write(data) self.other_side.transport.write(data)
@ -302,8 +301,12 @@ class TunnelProtocol(asyncio.Protocol):
) )
) )
self.notify_ticket = b'' # Clean up so no more notifications self.notify_ticket = b'' # Clean up so no more notifications
else: # No ticket, this is "main" connection (from client to us). Notify owner that we are done elif self.other_side is self:
self.owner.finished.set() # Simple log
logger.info('TERMINATED %s', self.pretty_source())
# In any case, ensure finished is set
self.owner.finished.set()
def connection_lost(self, exc: typing.Optional[Exception]) -> None: def connection_lost(self, exc: typing.Optional[Exception]) -> None:
# Ensure close other side if not server_side # Ensure close other side if not server_side
@ -330,12 +333,12 @@ class TunnelProtocol(asyncio.Protocol):
return TunnelProtocol.pretty_address(self.destination) return TunnelProtocol.pretty_address(self.destination)
def close_connection(self): def close_connection(self):
self.transport.close() try:
self.transport.close()
except Exception:
pass # Ignore errors
# If destination is not set, it's a command processing (i.e. TEST or STAT) # If destination is not set, it's a command processing (i.e. TEST or STAT)
if self.destination[0]: self.notify_end()
self.notify_end()
else:
logger.info('TERMINATED %s', self.pretty_source())
@staticmethod @staticmethod
async def _read_from_uds( async def _read_from_uds(