Ported to asyncio uds tunnel.

This commit is contained in:
Adolfo Gómez García 2022-01-27 13:10:07 +01:00
parent 2f37caaf22
commit e043a79721
2 changed files with 60 additions and 30 deletions

View File

@ -13,7 +13,12 @@ if typing.TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
ProcessType = typing.Callable[['Connection', config.ConfigurationType, 'Namespace'], typing.Coroutine[typing.Any, None, None]] ProcessType = typing.Callable[
['Connection', config.ConfigurationType, 'Namespace'],
typing.Coroutine[typing.Any, None, None],
]
NO_CPU_PERCENT = 100000.0
class Processes: class Processes:
""" """
@ -27,7 +32,9 @@ class Processes:
cfg: config.ConfigurationType cfg: config.ConfigurationType
ns: 'Namespace' ns: 'Namespace'
def __init__(self, process: ProcessType, cfg: config.ConfigurationType, ns: 'Namespace') -> None: def __init__(
self, process: ProcessType, cfg: config.ConfigurationType, ns: 'Namespace'
) -> None:
self.children = [] self.children = []
self.process = process # type: ignore self.process = process # type: ignore
self.cfg = cfg self.cfg = cfg
@ -47,7 +54,7 @@ class Processes:
self.children.append((own_conn, task, psutil.Process(task.pid))) self.children.append((own_conn, task, psutil.Process(task.pid)))
def best_child(self) -> 'Connection': def best_child(self) -> 'Connection':
best: typing.Tuple[float, 'Connection'] = (1000.0, self.children[0][0]) best: typing.Tuple[float, 'Connection'] = (NO_CPU_PERCENT, self.children[0][0])
missingProcesses: typing.List[int] = [] missingProcesses: typing.List[int] = []
for i, c in enumerate(self.children): for i, c in enumerate(self.children):
try: try:
@ -89,6 +96,10 @@ class Processes:
for i in range(len(missingProcesses)): for i in range(len(missingProcesses)):
self.add_child_pid() self.add_child_pid()
# Recheck best if all child were missing
if best[0] == NO_CPU_PERCENT:
return self.best_child()
return best[1] return best[1]
def stop(self): def stop(self):
@ -98,7 +109,12 @@ class Processes:
i[2].kill() i[2].kill()
except Exception as e: except Exception as e:
logger.info('KILLING child %s: %s', i[2], e) logger.info('KILLING child %s: %s', i[2], e)
@staticmethod @staticmethod
def runner(proc: ProcessType, conn: 'Connection', cfg: config.ConfigurationType, ns: 'Namespace') -> None: def runner(
proc: ProcessType,
conn: 'Connection',
cfg: config.ConfigurationType,
ns: 'Namespace',
) -> None:
asyncio.run(proc(conn, cfg, ns)) asyncio.run(proc(conn, cfg, ns))

View File

@ -36,18 +36,20 @@ class TunnelProtocol(asyncio.Protocol):
# counter # counter
counter: stats.StatsSingleCounter counter: stats.StatsSingleCounter
def __init__(self, owner: 'proxy.Proxy', other_side: typing.Optional['TunnelProtocol'] = None) -> None: def __init__(
self, owner: 'proxy.Proxy', other_side: typing.Optional['TunnelProtocol'] = None
) -> None:
# If no other side is given, we are the server part # If no other side is given, we are the server part
super().__init__() super().__init__()
if other_side: if other_side:
self.other_side = other_side self.other_side = other_side
self.stats_manager = other_side.stats_manager self.stats_manager = other_side.stats_manager
#self.counter = self.stats_manager.as_recv_counter() # self.counter = self.stats_manager.as_recv_counter()
self.runner = self.do_proxy self.runner = self.do_proxy
else: else:
self.other_side = self self.other_side = self
self.stats_manager = stats.Stats(owner.ns) self.stats_manager = stats.Stats(owner.ns)
#self.counter = self.stats_manager.as_sent_counter() # self.counter = self.stats_manager.as_sent_counter()
self.runner = self.do_command self.runner = self.do_command
# transport is undefined until connection_made is called # transport is undefined until connection_made is called
@ -57,14 +59,13 @@ class TunnelProtocol(asyncio.Protocol):
self.source = ('', 0) self.source = ('', 0)
self.destination = ('', 0) self.destination = ('', 0)
def process_open(self): def process_open(self):
# Open Command has the ticket behind it # Open Command has the ticket behind it
if len(self.cmd) < consts.TICKET_LENGTH + consts.COMMAND_LENGTH: if len(self.cmd) < consts.TICKET_LENGTH + consts.COMMAND_LENGTH:
return # Wait for more data return # Wait for more data to complete OPEN command
# Ticket received, now process it with UDS # Ticket received, now process it with UDS
ticket = self.cmd[consts.COMMAND_LENGTH:] ticket = self.cmd[consts.COMMAND_LENGTH :]
# clean up the command # clean up the command
self.cmd = b'' self.cmd = b''
@ -73,7 +74,9 @@ class TunnelProtocol(asyncio.Protocol):
async def open_other_side() -> None: async def open_other_side() -> None:
try: try:
result = await TunnelProtocol.getFromUds(self.owner.cfg, ticket, self.source) result = await TunnelProtocol.getFromUds(
self.owner.cfg, ticket, self.source
)
except Exception as e: except Exception as e:
logger.error('ERROR %s', e.args[0] if e.args else e) logger.error('ERROR %s', e.args[0] if e.args else e)
self.transport.write(b'ERROR_TICKET') self.transport.write(b'ERROR_TICKET')
@ -82,12 +85,18 @@ class TunnelProtocol(asyncio.Protocol):
# store for future use # store for future use
self.destination = (result['host'], int(result['port'])) self.destination = (result['host'], int(result['port']))
logger.info('OPEN TUNNEL FROM %s to %s', self.pretty_source(), self.pretty_destination()) logger.info(
'OPEN TUNNEL FROM %s to %s',
self.pretty_source(),
self.pretty_destination(),
)
try: try:
(_, protocol) = await loop.create_connection( (_, protocol) = await loop.create_connection(
lambda: TunnelProtocol(self.owner, self), self.destination[0], self.destination[1] lambda: TunnelProtocol(self.owner, self),
self.destination[0],
self.destination[1],
) )
self.other_side = typing.cast('TunnelProtocol', protocol) self.other_side = typing.cast('TunnelProtocol', protocol)
@ -106,7 +115,7 @@ class TunnelProtocol(asyncio.Protocol):
return return
try: try:
logger.info('COMMAND: %s', self.cmd[:consts.COMMAND_LENGTH].decode()) logger.info('COMMAND: %s', self.cmd[: consts.COMMAND_LENGTH].decode())
# Check valid source ip # Check valid source ip
if self.transport.get_extra_info('peername')[0] not in self.owner.cfg.allow: if self.transport.get_extra_info('peername')[0] not in self.owner.cfg.allow:
@ -115,10 +124,10 @@ class TunnelProtocol(asyncio.Protocol):
return return
# Check password # Check password
passwd = self.cmd[consts.COMMAND_LENGTH:] passwd = self.cmd[consts.COMMAND_LENGTH :]
# Clean up the command # Clean up the command, only keep base part
self.cmd = b'' self.cmd = self.cmd[:4]
if passwd.decode(errors='ignore') != self.owner.cfg.secret: if passwd.decode(errors='ignore') != self.owner.cfg.secret:
# Invalid password # Invalid password
@ -140,14 +149,13 @@ class TunnelProtocol(asyncio.Protocol):
if len(self.cmd) >= consts.COMMAND_LENGTH: if len(self.cmd) >= consts.COMMAND_LENGTH:
logger.info('CONNECT FROM %s', self.pretty_source()) logger.info('CONNECT FROM %s', self.pretty_source())
command = self.cmd[:consts.COMMAND_LENGTH] command = self.cmd[: consts.COMMAND_LENGTH]
try: try:
if command == consts.COMMAND_OPEN: if command == consts.COMMAND_OPEN:
self.process_open() self.process_open()
elif command == consts.COMMAND_TEST: elif command == consts.COMMAND_TEST:
logger.info('COMMAND: TEST') logger.info('COMMAND: TEST')
self.transport.write(b'OK') self.transport.write(b'OK')
self.close_connection() self.close_connection()
return return
elif command in (consts.COMMAND_STAT, consts.COMMAND_INFO): elif command in (consts.COMMAND_STAT, consts.COMMAND_INFO):
@ -161,16 +169,16 @@ class TunnelProtocol(asyncio.Protocol):
logger.exception('COMMAND') logger.exception('COMMAND')
logger.error('ERROR from %s', self.pretty_source()) logger.error('ERROR from %s', self.pretty_source())
self.transport.write(b'ERROR_COMMAND') self.transport.write(b'ERROR_COMMAND')
self.transport.close() # end the connection self.close_connection()
return return
# if not enough data, wait for more...
# if not enough data to process command, wait for more
def do_proxy(self, data: bytes) -> None: def do_proxy(self, data: bytes) -> None:
self.counter.add(len(data)) self.counter.add(len(data))
logger.debug('Processing proxy: %s', len(data)) logger.debug('Processing proxy: %s', len(data))
self.other_side.transport.write(data) self.other_side.transport.write(data)
# inherited from asyncio.Protocol # inherited from asyncio.Protocol
def connection_made(self, transport: 'asyncio.transports.BaseTransport') -> None: def connection_made(self, transport: 'asyncio.transports.BaseTransport') -> None:
@ -183,7 +191,7 @@ class TunnelProtocol(asyncio.Protocol):
def data_received(self, data: bytes): def data_received(self, data: bytes):
logger.debug('Data received: %s', len(data)) logger.debug('Data received: %s', len(data))
self.runner(data) self.runner(data) # send data to current runner (command or proxy)
def connection_lost(self, exc: typing.Optional[Exception]) -> None: def connection_lost(self, exc: typing.Optional[Exception]) -> None:
logger.debug('Connection closed : %s', exc) logger.debug('Connection closed : %s', exc)
@ -262,13 +270,19 @@ class TunnelProtocol(asyncio.Protocol):
continue # Correctus continue # Correctus
raise Exception(f'TICKET INVALID (char {i} at pos {n})') raise Exception(f'TICKET INVALID (char {i} at pos {n})')
return await asyncio.get_event_loop().run_in_executor(None, TunnelProtocol._getUdsUrl, cfg, ticket, address[0]) return await asyncio.get_event_loop().run_in_executor(
None, TunnelProtocol._getUdsUrl, cfg, ticket, address[0]
)
@staticmethod @staticmethod
async def notifyEndToUds( async def notifyEndToUds(
cfg: config.ConfigurationType, ticket: bytes, counter: stats.Stats cfg: config.ConfigurationType, ticket: bytes, counter: stats.Stats
) -> None: ) -> None:
TunnelProtocol._getUdsUrl( await asyncio.get_event_loop().run_in_executor(
cfg, ticket, 'stop', {'sent': str(counter.sent), 'recv': str(counter.recv)} None,
) # Ignore results TunnelProtocol._getUdsUrl,
cfg,
ticket,
'stop',
{'sent': str(counter.sent), 'recv': str(counter.recv)},
)