forked from shaba/openuds
Ported to asyncio uds tunnel.
This commit is contained in:
parent
2f37caaf22
commit
e043a79721
@ -13,7 +13,12 @@ if typing.TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ProcessType = typing.Callable[['Connection', config.ConfigurationType, 'Namespace'], typing.Coroutine[typing.Any, None, None]]
|
||||
ProcessType = typing.Callable[
|
||||
['Connection', config.ConfigurationType, 'Namespace'],
|
||||
typing.Coroutine[typing.Any, None, None],
|
||||
]
|
||||
|
||||
NO_CPU_PERCENT = 100000.0
|
||||
|
||||
class Processes:
|
||||
"""
|
||||
@ -27,7 +32,9 @@ class Processes:
|
||||
cfg: config.ConfigurationType
|
||||
ns: 'Namespace'
|
||||
|
||||
def __init__(self, process: ProcessType, cfg: config.ConfigurationType, ns: 'Namespace') -> None:
|
||||
def __init__(
|
||||
self, process: ProcessType, cfg: config.ConfigurationType, ns: 'Namespace'
|
||||
) -> None:
|
||||
self.children = []
|
||||
self.process = process # type: ignore
|
||||
self.cfg = cfg
|
||||
@ -47,7 +54,7 @@ class Processes:
|
||||
self.children.append((own_conn, task, psutil.Process(task.pid)))
|
||||
|
||||
def best_child(self) -> 'Connection':
|
||||
best: typing.Tuple[float, 'Connection'] = (1000.0, self.children[0][0])
|
||||
best: typing.Tuple[float, 'Connection'] = (NO_CPU_PERCENT, self.children[0][0])
|
||||
missingProcesses: typing.List[int] = []
|
||||
for i, c in enumerate(self.children):
|
||||
try:
|
||||
@ -89,6 +96,10 @@ class Processes:
|
||||
for i in range(len(missingProcesses)):
|
||||
self.add_child_pid()
|
||||
|
||||
# Recheck best if all child were missing
|
||||
if best[0] == NO_CPU_PERCENT:
|
||||
return self.best_child()
|
||||
|
||||
return best[1]
|
||||
|
||||
def stop(self):
|
||||
@ -98,7 +109,12 @@ class Processes:
|
||||
i[2].kill()
|
||||
except Exception as e:
|
||||
logger.info('KILLING child %s: %s', i[2], e)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def runner(proc: ProcessType, conn: 'Connection', cfg: config.ConfigurationType, ns: 'Namespace') -> None:
|
||||
def runner(
|
||||
proc: ProcessType,
|
||||
conn: 'Connection',
|
||||
cfg: config.ConfigurationType,
|
||||
ns: 'Namespace',
|
||||
) -> None:
|
||||
asyncio.run(proc(conn, cfg, ns))
|
||||
|
@ -36,18 +36,20 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
# counter
|
||||
counter: stats.StatsSingleCounter
|
||||
|
||||
def __init__(self, owner: 'proxy.Proxy', other_side: typing.Optional['TunnelProtocol'] = None) -> None:
|
||||
def __init__(
|
||||
self, owner: 'proxy.Proxy', other_side: typing.Optional['TunnelProtocol'] = None
|
||||
) -> None:
|
||||
# If no other side is given, we are the server part
|
||||
super().__init__()
|
||||
if other_side:
|
||||
self.other_side = other_side
|
||||
self.stats_manager = other_side.stats_manager
|
||||
#self.counter = self.stats_manager.as_recv_counter()
|
||||
# self.counter = self.stats_manager.as_recv_counter()
|
||||
self.runner = self.do_proxy
|
||||
else:
|
||||
self.other_side = self
|
||||
self.stats_manager = stats.Stats(owner.ns)
|
||||
#self.counter = self.stats_manager.as_sent_counter()
|
||||
# self.counter = self.stats_manager.as_sent_counter()
|
||||
self.runner = self.do_command
|
||||
|
||||
# transport is undefined until connection_made is called
|
||||
@ -57,14 +59,13 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
self.source = ('', 0)
|
||||
self.destination = ('', 0)
|
||||
|
||||
|
||||
def process_open(self):
|
||||
# Open Command has the ticket behind it
|
||||
if len(self.cmd) < consts.TICKET_LENGTH + consts.COMMAND_LENGTH:
|
||||
return # Wait for more data
|
||||
return # Wait for more data to complete OPEN command
|
||||
|
||||
# Ticket received, now process it with UDS
|
||||
ticket = self.cmd[consts.COMMAND_LENGTH:]
|
||||
ticket = self.cmd[consts.COMMAND_LENGTH :]
|
||||
|
||||
# clean up the command
|
||||
self.cmd = b''
|
||||
@ -73,7 +74,9 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
|
||||
async def open_other_side() -> None:
|
||||
try:
|
||||
result = await TunnelProtocol.getFromUds(self.owner.cfg, ticket, self.source)
|
||||
result = await TunnelProtocol.getFromUds(
|
||||
self.owner.cfg, ticket, self.source
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error('ERROR %s', e.args[0] if e.args else e)
|
||||
self.transport.write(b'ERROR_TICKET')
|
||||
@ -82,12 +85,18 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
|
||||
# store for future use
|
||||
self.destination = (result['host'], int(result['port']))
|
||||
|
||||
logger.info('OPEN TUNNEL FROM %s to %s', self.pretty_source(), self.pretty_destination())
|
||||
|
||||
logger.info(
|
||||
'OPEN TUNNEL FROM %s to %s',
|
||||
self.pretty_source(),
|
||||
self.pretty_destination(),
|
||||
)
|
||||
|
||||
try:
|
||||
(_, protocol) = await loop.create_connection(
|
||||
lambda: TunnelProtocol(self.owner, self), self.destination[0], self.destination[1]
|
||||
lambda: TunnelProtocol(self.owner, self),
|
||||
self.destination[0],
|
||||
self.destination[1],
|
||||
)
|
||||
self.other_side = typing.cast('TunnelProtocol', protocol)
|
||||
|
||||
@ -106,7 +115,7 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
return
|
||||
|
||||
try:
|
||||
logger.info('COMMAND: %s', self.cmd[:consts.COMMAND_LENGTH].decode())
|
||||
logger.info('COMMAND: %s', self.cmd[: consts.COMMAND_LENGTH].decode())
|
||||
|
||||
# Check valid source ip
|
||||
if self.transport.get_extra_info('peername')[0] not in self.owner.cfg.allow:
|
||||
@ -115,10 +124,10 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
return
|
||||
|
||||
# Check password
|
||||
passwd = self.cmd[consts.COMMAND_LENGTH:]
|
||||
passwd = self.cmd[consts.COMMAND_LENGTH :]
|
||||
|
||||
# Clean up the command
|
||||
self.cmd = b''
|
||||
# Clean up the command, only keep base part
|
||||
self.cmd = self.cmd[:4]
|
||||
|
||||
if passwd.decode(errors='ignore') != self.owner.cfg.secret:
|
||||
# Invalid password
|
||||
@ -140,14 +149,13 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
if len(self.cmd) >= consts.COMMAND_LENGTH:
|
||||
logger.info('CONNECT FROM %s', self.pretty_source())
|
||||
|
||||
command = self.cmd[:consts.COMMAND_LENGTH]
|
||||
command = self.cmd[: consts.COMMAND_LENGTH]
|
||||
try:
|
||||
if command == consts.COMMAND_OPEN:
|
||||
self.process_open()
|
||||
elif command == consts.COMMAND_TEST:
|
||||
logger.info('COMMAND: TEST')
|
||||
self.transport.write(b'OK')
|
||||
|
||||
self.close_connection()
|
||||
return
|
||||
elif command in (consts.COMMAND_STAT, consts.COMMAND_INFO):
|
||||
@ -161,16 +169,16 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
logger.exception('COMMAND')
|
||||
logger.error('ERROR from %s', self.pretty_source())
|
||||
self.transport.write(b'ERROR_COMMAND')
|
||||
self.transport.close() # end the connection
|
||||
self.close_connection()
|
||||
return
|
||||
# if not enough data, wait for more...
|
||||
|
||||
# if not enough data to process command, wait for more
|
||||
|
||||
def do_proxy(self, data: bytes) -> None:
|
||||
self.counter.add(len(data))
|
||||
logger.debug('Processing proxy: %s', len(data))
|
||||
self.other_side.transport.write(data)
|
||||
|
||||
|
||||
# inherited from asyncio.Protocol
|
||||
|
||||
def connection_made(self, transport: 'asyncio.transports.BaseTransport') -> None:
|
||||
@ -183,7 +191,7 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
|
||||
def data_received(self, data: bytes):
|
||||
logger.debug('Data received: %s', len(data))
|
||||
self.runner(data)
|
||||
self.runner(data) # send data to current runner (command or proxy)
|
||||
|
||||
def connection_lost(self, exc: typing.Optional[Exception]) -> None:
|
||||
logger.debug('Connection closed : %s', exc)
|
||||
@ -262,13 +270,19 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
continue # Correctus
|
||||
raise Exception(f'TICKET INVALID (char {i} at pos {n})')
|
||||
|
||||
return await asyncio.get_event_loop().run_in_executor(None, TunnelProtocol._getUdsUrl, cfg, ticket, address[0])
|
||||
return await asyncio.get_event_loop().run_in_executor(
|
||||
None, TunnelProtocol._getUdsUrl, cfg, ticket, address[0]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def notifyEndToUds(
|
||||
cfg: config.ConfigurationType, ticket: bytes, counter: stats.Stats
|
||||
) -> None:
|
||||
TunnelProtocol._getUdsUrl(
|
||||
cfg, ticket, 'stop', {'sent': str(counter.sent), 'recv': str(counter.recv)}
|
||||
) # Ignore results
|
||||
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
TunnelProtocol._getUdsUrl,
|
||||
cfg,
|
||||
ticket,
|
||||
'stop',
|
||||
{'sent': str(counter.sent), 'recv': str(counter.recv)},
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user