forked from shaba/openuds
Ported to asyncio uds tunnel.
This commit is contained in:
parent
2f37caaf22
commit
e043a79721
@ -13,7 +13,12 @@ if typing.TYPE_CHECKING:
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
ProcessType = typing.Callable[['Connection', config.ConfigurationType, 'Namespace'], typing.Coroutine[typing.Any, None, None]]
|
ProcessType = typing.Callable[
|
||||||
|
['Connection', config.ConfigurationType, 'Namespace'],
|
||||||
|
typing.Coroutine[typing.Any, None, None],
|
||||||
|
]
|
||||||
|
|
||||||
|
NO_CPU_PERCENT = 100000.0
|
||||||
|
|
||||||
class Processes:
|
class Processes:
|
||||||
"""
|
"""
|
||||||
@ -27,7 +32,9 @@ class Processes:
|
|||||||
cfg: config.ConfigurationType
|
cfg: config.ConfigurationType
|
||||||
ns: 'Namespace'
|
ns: 'Namespace'
|
||||||
|
|
||||||
def __init__(self, process: ProcessType, cfg: config.ConfigurationType, ns: 'Namespace') -> None:
|
def __init__(
|
||||||
|
self, process: ProcessType, cfg: config.ConfigurationType, ns: 'Namespace'
|
||||||
|
) -> None:
|
||||||
self.children = []
|
self.children = []
|
||||||
self.process = process # type: ignore
|
self.process = process # type: ignore
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
@ -47,7 +54,7 @@ class Processes:
|
|||||||
self.children.append((own_conn, task, psutil.Process(task.pid)))
|
self.children.append((own_conn, task, psutil.Process(task.pid)))
|
||||||
|
|
||||||
def best_child(self) -> 'Connection':
|
def best_child(self) -> 'Connection':
|
||||||
best: typing.Tuple[float, 'Connection'] = (1000.0, self.children[0][0])
|
best: typing.Tuple[float, 'Connection'] = (NO_CPU_PERCENT, self.children[0][0])
|
||||||
missingProcesses: typing.List[int] = []
|
missingProcesses: typing.List[int] = []
|
||||||
for i, c in enumerate(self.children):
|
for i, c in enumerate(self.children):
|
||||||
try:
|
try:
|
||||||
@ -89,6 +96,10 @@ class Processes:
|
|||||||
for i in range(len(missingProcesses)):
|
for i in range(len(missingProcesses)):
|
||||||
self.add_child_pid()
|
self.add_child_pid()
|
||||||
|
|
||||||
|
# Recheck best if all child were missing
|
||||||
|
if best[0] == NO_CPU_PERCENT:
|
||||||
|
return self.best_child()
|
||||||
|
|
||||||
return best[1]
|
return best[1]
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
@ -98,7 +109,12 @@ class Processes:
|
|||||||
i[2].kill()
|
i[2].kill()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.info('KILLING child %s: %s', i[2], e)
|
logger.info('KILLING child %s: %s', i[2], e)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def runner(proc: ProcessType, conn: 'Connection', cfg: config.ConfigurationType, ns: 'Namespace') -> None:
|
def runner(
|
||||||
|
proc: ProcessType,
|
||||||
|
conn: 'Connection',
|
||||||
|
cfg: config.ConfigurationType,
|
||||||
|
ns: 'Namespace',
|
||||||
|
) -> None:
|
||||||
asyncio.run(proc(conn, cfg, ns))
|
asyncio.run(proc(conn, cfg, ns))
|
||||||
|
@ -36,18 +36,20 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
# counter
|
# counter
|
||||||
counter: stats.StatsSingleCounter
|
counter: stats.StatsSingleCounter
|
||||||
|
|
||||||
def __init__(self, owner: 'proxy.Proxy', other_side: typing.Optional['TunnelProtocol'] = None) -> None:
|
def __init__(
|
||||||
|
self, owner: 'proxy.Proxy', other_side: typing.Optional['TunnelProtocol'] = None
|
||||||
|
) -> None:
|
||||||
# If no other side is given, we are the server part
|
# If no other side is given, we are the server part
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if other_side:
|
if other_side:
|
||||||
self.other_side = other_side
|
self.other_side = other_side
|
||||||
self.stats_manager = other_side.stats_manager
|
self.stats_manager = other_side.stats_manager
|
||||||
#self.counter = self.stats_manager.as_recv_counter()
|
# self.counter = self.stats_manager.as_recv_counter()
|
||||||
self.runner = self.do_proxy
|
self.runner = self.do_proxy
|
||||||
else:
|
else:
|
||||||
self.other_side = self
|
self.other_side = self
|
||||||
self.stats_manager = stats.Stats(owner.ns)
|
self.stats_manager = stats.Stats(owner.ns)
|
||||||
#self.counter = self.stats_manager.as_sent_counter()
|
# self.counter = self.stats_manager.as_sent_counter()
|
||||||
self.runner = self.do_command
|
self.runner = self.do_command
|
||||||
|
|
||||||
# transport is undefined until connection_made is called
|
# transport is undefined until connection_made is called
|
||||||
@ -57,14 +59,13 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
self.source = ('', 0)
|
self.source = ('', 0)
|
||||||
self.destination = ('', 0)
|
self.destination = ('', 0)
|
||||||
|
|
||||||
|
|
||||||
def process_open(self):
|
def process_open(self):
|
||||||
# Open Command has the ticket behind it
|
# Open Command has the ticket behind it
|
||||||
if len(self.cmd) < consts.TICKET_LENGTH + consts.COMMAND_LENGTH:
|
if len(self.cmd) < consts.TICKET_LENGTH + consts.COMMAND_LENGTH:
|
||||||
return # Wait for more data
|
return # Wait for more data to complete OPEN command
|
||||||
|
|
||||||
# Ticket received, now process it with UDS
|
# Ticket received, now process it with UDS
|
||||||
ticket = self.cmd[consts.COMMAND_LENGTH:]
|
ticket = self.cmd[consts.COMMAND_LENGTH :]
|
||||||
|
|
||||||
# clean up the command
|
# clean up the command
|
||||||
self.cmd = b''
|
self.cmd = b''
|
||||||
@ -73,7 +74,9 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
|
|
||||||
async def open_other_side() -> None:
|
async def open_other_side() -> None:
|
||||||
try:
|
try:
|
||||||
result = await TunnelProtocol.getFromUds(self.owner.cfg, ticket, self.source)
|
result = await TunnelProtocol.getFromUds(
|
||||||
|
self.owner.cfg, ticket, self.source
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error('ERROR %s', e.args[0] if e.args else e)
|
logger.error('ERROR %s', e.args[0] if e.args else e)
|
||||||
self.transport.write(b'ERROR_TICKET')
|
self.transport.write(b'ERROR_TICKET')
|
||||||
@ -82,12 +85,18 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
|
|
||||||
# store for future use
|
# store for future use
|
||||||
self.destination = (result['host'], int(result['port']))
|
self.destination = (result['host'], int(result['port']))
|
||||||
|
|
||||||
logger.info('OPEN TUNNEL FROM %s to %s', self.pretty_source(), self.pretty_destination())
|
logger.info(
|
||||||
|
'OPEN TUNNEL FROM %s to %s',
|
||||||
|
self.pretty_source(),
|
||||||
|
self.pretty_destination(),
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
(_, protocol) = await loop.create_connection(
|
(_, protocol) = await loop.create_connection(
|
||||||
lambda: TunnelProtocol(self.owner, self), self.destination[0], self.destination[1]
|
lambda: TunnelProtocol(self.owner, self),
|
||||||
|
self.destination[0],
|
||||||
|
self.destination[1],
|
||||||
)
|
)
|
||||||
self.other_side = typing.cast('TunnelProtocol', protocol)
|
self.other_side = typing.cast('TunnelProtocol', protocol)
|
||||||
|
|
||||||
@ -106,7 +115,7 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info('COMMAND: %s', self.cmd[:consts.COMMAND_LENGTH].decode())
|
logger.info('COMMAND: %s', self.cmd[: consts.COMMAND_LENGTH].decode())
|
||||||
|
|
||||||
# Check valid source ip
|
# Check valid source ip
|
||||||
if self.transport.get_extra_info('peername')[0] not in self.owner.cfg.allow:
|
if self.transport.get_extra_info('peername')[0] not in self.owner.cfg.allow:
|
||||||
@ -115,10 +124,10 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Check password
|
# Check password
|
||||||
passwd = self.cmd[consts.COMMAND_LENGTH:]
|
passwd = self.cmd[consts.COMMAND_LENGTH :]
|
||||||
|
|
||||||
# Clean up the command
|
# Clean up the command, only keep base part
|
||||||
self.cmd = b''
|
self.cmd = self.cmd[:4]
|
||||||
|
|
||||||
if passwd.decode(errors='ignore') != self.owner.cfg.secret:
|
if passwd.decode(errors='ignore') != self.owner.cfg.secret:
|
||||||
# Invalid password
|
# Invalid password
|
||||||
@ -140,14 +149,13 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
if len(self.cmd) >= consts.COMMAND_LENGTH:
|
if len(self.cmd) >= consts.COMMAND_LENGTH:
|
||||||
logger.info('CONNECT FROM %s', self.pretty_source())
|
logger.info('CONNECT FROM %s', self.pretty_source())
|
||||||
|
|
||||||
command = self.cmd[:consts.COMMAND_LENGTH]
|
command = self.cmd[: consts.COMMAND_LENGTH]
|
||||||
try:
|
try:
|
||||||
if command == consts.COMMAND_OPEN:
|
if command == consts.COMMAND_OPEN:
|
||||||
self.process_open()
|
self.process_open()
|
||||||
elif command == consts.COMMAND_TEST:
|
elif command == consts.COMMAND_TEST:
|
||||||
logger.info('COMMAND: TEST')
|
logger.info('COMMAND: TEST')
|
||||||
self.transport.write(b'OK')
|
self.transport.write(b'OK')
|
||||||
|
|
||||||
self.close_connection()
|
self.close_connection()
|
||||||
return
|
return
|
||||||
elif command in (consts.COMMAND_STAT, consts.COMMAND_INFO):
|
elif command in (consts.COMMAND_STAT, consts.COMMAND_INFO):
|
||||||
@ -161,16 +169,16 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
logger.exception('COMMAND')
|
logger.exception('COMMAND')
|
||||||
logger.error('ERROR from %s', self.pretty_source())
|
logger.error('ERROR from %s', self.pretty_source())
|
||||||
self.transport.write(b'ERROR_COMMAND')
|
self.transport.write(b'ERROR_COMMAND')
|
||||||
self.transport.close() # end the connection
|
self.close_connection()
|
||||||
return
|
return
|
||||||
# if not enough data, wait for more...
|
|
||||||
|
# if not enough data to process command, wait for more
|
||||||
|
|
||||||
def do_proxy(self, data: bytes) -> None:
|
def do_proxy(self, data: bytes) -> None:
|
||||||
self.counter.add(len(data))
|
self.counter.add(len(data))
|
||||||
logger.debug('Processing proxy: %s', len(data))
|
logger.debug('Processing proxy: %s', len(data))
|
||||||
self.other_side.transport.write(data)
|
self.other_side.transport.write(data)
|
||||||
|
|
||||||
|
|
||||||
# inherited from asyncio.Protocol
|
# inherited from asyncio.Protocol
|
||||||
|
|
||||||
def connection_made(self, transport: 'asyncio.transports.BaseTransport') -> None:
|
def connection_made(self, transport: 'asyncio.transports.BaseTransport') -> None:
|
||||||
@ -183,7 +191,7 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
|
|
||||||
def data_received(self, data: bytes):
|
def data_received(self, data: bytes):
|
||||||
logger.debug('Data received: %s', len(data))
|
logger.debug('Data received: %s', len(data))
|
||||||
self.runner(data)
|
self.runner(data) # send data to current runner (command or proxy)
|
||||||
|
|
||||||
def connection_lost(self, exc: typing.Optional[Exception]) -> None:
|
def connection_lost(self, exc: typing.Optional[Exception]) -> None:
|
||||||
logger.debug('Connection closed : %s', exc)
|
logger.debug('Connection closed : %s', exc)
|
||||||
@ -262,13 +270,19 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
continue # Correctus
|
continue # Correctus
|
||||||
raise Exception(f'TICKET INVALID (char {i} at pos {n})')
|
raise Exception(f'TICKET INVALID (char {i} at pos {n})')
|
||||||
|
|
||||||
return await asyncio.get_event_loop().run_in_executor(None, TunnelProtocol._getUdsUrl, cfg, ticket, address[0])
|
return await asyncio.get_event_loop().run_in_executor(
|
||||||
|
None, TunnelProtocol._getUdsUrl, cfg, ticket, address[0]
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def notifyEndToUds(
|
async def notifyEndToUds(
|
||||||
cfg: config.ConfigurationType, ticket: bytes, counter: stats.Stats
|
cfg: config.ConfigurationType, ticket: bytes, counter: stats.Stats
|
||||||
) -> None:
|
) -> None:
|
||||||
TunnelProtocol._getUdsUrl(
|
await asyncio.get_event_loop().run_in_executor(
|
||||||
cfg, ticket, 'stop', {'sent': str(counter.sent), 'recv': str(counter.recv)}
|
None,
|
||||||
) # Ignore results
|
TunnelProtocol._getUdsUrl,
|
||||||
|
cfg,
|
||||||
|
ticket,
|
||||||
|
'stop',
|
||||||
|
{'sent': str(counter.sent), 'recv': str(counter.recv)},
|
||||||
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user