forked from shaba/openuds
Removing curio from tunneler, so we do not have an unneeded dependency
This commit is contained in:
parent
7d8ae689b5
commit
77b6eff8e4
@ -58,3 +58,4 @@ COMMAND_OPEN = b'OPEN'
|
||||
COMMAND_TEST = b'TEST'
|
||||
COMMAND_STAT = b'STAT' # full stats
|
||||
COMMAND_INFO = b'INFO' # Basic stats, currently same as FULL
|
||||
|
||||
|
@ -39,8 +39,8 @@ class Processes:
|
||||
def add_child_pid(self):
|
||||
own_conn, child_conn = multiprocessing.Pipe()
|
||||
task = multiprocessing.Process(
|
||||
target=asyncio.run,
|
||||
args=(self.process(child_conn, self.cfg, self.ns),)
|
||||
target=Processes.runner,
|
||||
args=(self.process, child_conn, self.cfg, self.ns),
|
||||
)
|
||||
task.start()
|
||||
logger.debug('ADD CHILD PID: %s', task.pid)
|
||||
@ -98,4 +98,7 @@ class Processes:
|
||||
i[2].kill()
|
||||
except Exception as e:
|
||||
logger.info('KILLING child %s: %s', i[2], e)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def runner(proc: ProcessType, conn: 'Connection', cfg: config.ConfigurationType, ns: 'Namespace') -> None:
|
||||
asyncio.run(proc(conn, cfg, ns))
|
||||
|
@ -33,11 +33,8 @@ import socket
|
||||
import logging
|
||||
import typing
|
||||
|
||||
import requests
|
||||
|
||||
from . import config
|
||||
from . import stats
|
||||
from . import consts
|
||||
from . import tunnel
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from multiprocessing.managers import Namespace
|
||||
@ -54,196 +51,109 @@ class Proxy:
|
||||
self.cfg = cfg
|
||||
self.ns = ns
|
||||
|
||||
@staticmethod
|
||||
def _getUdsUrl(
|
||||
cfg: config.ConfigurationType,
|
||||
ticket: bytes,
|
||||
msg: str,
|
||||
queryParams: typing.Mapping[str, str] = None,
|
||||
) -> typing.MutableMapping[str, typing.Any]:
|
||||
try:
|
||||
url = (
|
||||
cfg.uds_server + '/' + ticket.decode() + '/' + msg + '/' + cfg.uds_token
|
||||
)
|
||||
if queryParams:
|
||||
url += '?' + '&'.join(
|
||||
[f'{key}={value}' for key, value in queryParams.items()]
|
||||
)
|
||||
r = requests.get(
|
||||
url,
|
||||
headers={
|
||||
'content-type': 'application/json',
|
||||
'User-Agent': f'UDSTunnel-{consts.VERSION}',
|
||||
},
|
||||
)
|
||||
if not r.ok:
|
||||
raise Exception(r.content or 'Invalid Ticket (timed out)')
|
||||
|
||||
return r.json()
|
||||
except Exception as e:
|
||||
raise Exception(f'TICKET COMMS ERROR: {ticket.decode()} {msg} {e!s}')
|
||||
|
||||
@staticmethod
|
||||
def getFromUds(
|
||||
cfg: config.ConfigurationType, ticket: bytes, address: typing.Tuple[str, int]
|
||||
) -> typing.MutableMapping[str, typing.Any]:
|
||||
# Sanity checks
|
||||
if len(ticket) != consts.TICKET_LENGTH:
|
||||
raise Exception(f'TICKET INVALID (len={len(ticket)})')
|
||||
|
||||
for n, i in enumerate(ticket.decode(errors='ignore')):
|
||||
if (
|
||||
(i >= 'a' and i <= 'z')
|
||||
or (i >= '0' and i <= '9')
|
||||
or (i >= 'A' and i <= 'Z')
|
||||
):
|
||||
continue # Correctus
|
||||
raise Exception(f'TICKET INVALID (char {i} at pos {n})')
|
||||
|
||||
return Proxy._getUdsUrl(cfg, ticket, address[0])
|
||||
|
||||
@staticmethod
|
||||
def notifyEndToUds(
|
||||
cfg: config.ConfigurationType, ticket: bytes, counter: stats.Stats
|
||||
) -> None:
|
||||
Proxy._getUdsUrl(
|
||||
cfg, ticket, 'stop', {'sent': str(counter.sent), 'recv': str(counter.recv)}
|
||||
) # Ignore results
|
||||
|
||||
# @staticmethod
|
||||
# async def doProxy(
|
||||
# source: 'curio.io.Socket',
|
||||
# destination: 'curio.io.Socket',
|
||||
# counter: stats.StatsSingleCounter,
|
||||
# ) -> None:
|
||||
# try:
|
||||
# while True:
|
||||
# data = await source.recv(consts.BUFFER_SIZE)
|
||||
# if not data:
|
||||
# break
|
||||
# await destination.sendall(data)
|
||||
# counter.add(len(data))
|
||||
# except Exception:
|
||||
# # Connection broken, same result as closed for us
|
||||
# # We must notice that i'ts easy that when closing one part of the tunnel,
|
||||
# # the other can break (due to some internal data), that's why even log is removed
|
||||
# # logger.info('CONNECTION LOST FROM %s to %s', source.getsockname(), destination.getpeername())
|
||||
# pass
|
||||
|
||||
# Method responsible of proxying requests
|
||||
async def __call__(self, source: socket.socket, address: typing.Tuple[str, int], context: 'ssl.SSLContext') -> None:
|
||||
await self.proxy(source, address, context)
|
||||
async def __call__(self, source: socket.socket, context: 'ssl.SSLContext') -> None:
|
||||
await self.proxy(source, context)
|
||||
|
||||
async def stats(self, full: bool, source, address: typing.Tuple[str, int]) -> None:
|
||||
# Check valid source ip
|
||||
if address[0] not in self.cfg.allow:
|
||||
# Invalid source
|
||||
await source.sendall(b'FORBIDDEN')
|
||||
return
|
||||
|
||||
# Check password
|
||||
passwd = await source.recv(consts.PASSWORD_LENGTH)
|
||||
if passwd.decode(errors='ignore') != self.cfg.secret:
|
||||
# Invalid password
|
||||
await source.sendall(b'FORBIDDEN')
|
||||
return
|
||||
|
||||
data = stats.GlobalStats.get_stats(self.ns)
|
||||
|
||||
for v in data:
|
||||
logger.debug('SENDING %s', v)
|
||||
await source.sendall(v.encode() + b'\n')
|
||||
|
||||
async def proxy(self, source: socket.socket, address: typing.Tuple[str, int], context: 'ssl.SSLContext') -> None:
|
||||
prettySource = address[0] # Get only source IP
|
||||
prettyDest = ''
|
||||
logger.info('CONNECT FROM %s', prettySource)
|
||||
async def proxy(self, source: socket.socket, context: 'ssl.SSLContext') -> None:
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
# Handshake correct in this point, start SSL connection
|
||||
try:
|
||||
command: bytes = await loop.sock_recv(source, consts.COMMAND_LENGTH)
|
||||
if command == consts.COMMAND_TEST:
|
||||
logger.info('COMMAND: TEST')
|
||||
await loop.sock_sendall(source, b'OK')
|
||||
logger.info('TERMINATED %s', prettySource)
|
||||
return
|
||||
|
||||
if command in (consts.COMMAND_STAT, consts.COMMAND_INFO):
|
||||
logger.info('COMMAND: %s', command.decode())
|
||||
# This is an stats requests
|
||||
await self.stats(
|
||||
full=command == consts.COMMAND_STAT, source=source, address=address
|
||||
)
|
||||
logger.info('TERMINATED %s', prettySource)
|
||||
return
|
||||
|
||||
if command != consts.COMMAND_OPEN:
|
||||
# Invalid command
|
||||
raise Exception()
|
||||
|
||||
# Now, read a TICKET_LENGTH (64) bytes string, that must be [a-zA-Z0-9]{64}
|
||||
ticket: bytes = await source.recv(consts.TICKET_LENGTH)
|
||||
|
||||
# Ticket received, now process it with UDS
|
||||
try:
|
||||
result = await curio.run_in_thread(
|
||||
Proxy.getFromUds, self.cfg, ticket, address
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error('ERROR %s', e.args[0] if e.args else e)
|
||||
await source.sendall(b'ERROR_TICKET')
|
||||
return
|
||||
|
||||
prettyDest = f"{result['host']}:{result['port']}"
|
||||
logger.info('OPEN TUNNEL FROM %s to %s', prettySource, prettyDest)
|
||||
|
||||
except Exception:
|
||||
if consts.DEBUG:
|
||||
logger.exception('COMMAND')
|
||||
logger.error('ERROR from %s', prettySource)
|
||||
await source.sendall(b'ERROR_COMMAND')
|
||||
return
|
||||
|
||||
# Communicate source OPEN is ok
|
||||
await source.sendall(b'OK')
|
||||
|
||||
# Initialize own stats counter
|
||||
counter = stats.Stats(self.ns)
|
||||
|
||||
# Open remote server connection
|
||||
try:
|
||||
destination = await curio.open_connection(
|
||||
result['host'], int(result['port'])
|
||||
)
|
||||
async with curio.TaskGroup(wait=any) as grp:
|
||||
await grp.spawn(
|
||||
Proxy.doProxy, source, destination, counter.as_sent_counter()
|
||||
)
|
||||
await grp.spawn(
|
||||
Proxy.doProxy, destination, source, counter.as_recv_counter()
|
||||
)
|
||||
logger.debug('PROXIES READY')
|
||||
|
||||
logger.debug('Proxies finalized: %s', grp.exceptions)
|
||||
await curio.run_in_thread(
|
||||
Proxy.notifyEndToUds, self.cfg, result['notify'].encode(), counter
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
if consts.DEBUG:
|
||||
logger.exception('OPEN REMOTE')
|
||||
|
||||
logger.error('REMOTE from %s: %s', address, e)
|
||||
finally:
|
||||
counter.close() # So we ensure stats are correctly updated on ns
|
||||
|
||||
logger.info(
|
||||
'TERMINATED %s to %s, s:%s, r:%s, t:%s',
|
||||
prettySource,
|
||||
prettyDest,
|
||||
counter.sent,
|
||||
counter.recv,
|
||||
int(counter.end - counter.start),
|
||||
# Handshake correct in this point, upgrade the connection to TSL and let
|
||||
# the protocol controller do the rest
|
||||
|
||||
# Upgrade connection to SSL, and use asyncio to handle the rest
|
||||
transport: 'asyncio.transports.Transport'
|
||||
protocol: tunnel.TunnelProtocol
|
||||
(transport, protocol) = await loop.connect_accepted_socket( # type: ignore
|
||||
lambda: tunnel.TunnelProtocol(self), source, ssl=context
|
||||
)
|
||||
|
||||
await protocol.finished
|
||||
return
|
||||
|
||||
# try:
|
||||
# command: bytes = await loop.sock_recv(source, consts.COMMAND_LENGTH)
|
||||
# if command == consts.COMMAND_TEST:
|
||||
# logger.info('COMMAND: TEST')
|
||||
# await loop.sock_sendall(source, b'OK')
|
||||
# logger.info('TERMINATED %s', prettySource)
|
||||
# return
|
||||
|
||||
# if command in (consts.COMMAND_STAT, consts.COMMAND_INFO):
|
||||
# logger.info('COMMAND: %s', command.decode())
|
||||
# # This is an stats requests
|
||||
# await self.stats(
|
||||
# full=command == consts.COMMAND_STAT, source=source, address=address
|
||||
# )
|
||||
# logger.info('TERMINATED %s', prettySource)
|
||||
# return
|
||||
|
||||
# if command != consts.COMMAND_OPEN:
|
||||
# # Invalid command
|
||||
# raise Exception()
|
||||
|
||||
# # Now, read a TICKET_LENGTH (64) bytes string, that must be [a-zA-Z0-9]{64}
|
||||
# ticket: bytes = await source.recv(consts.TICKET_LENGTH)
|
||||
|
||||
# # Ticket received, now process it with UDS
|
||||
# try:
|
||||
# result = await curio.run_in_thread(
|
||||
# Proxy.getFromUds, self.cfg, ticket, address
|
||||
# )
|
||||
# except Exception as e:
|
||||
# logger.error('ERROR %s', e.args[0] if e.args else e)
|
||||
# await source.sendall(b'ERROR_TICKET')
|
||||
# return
|
||||
|
||||
# prettyDest = f"{result['host']}:{result['port']}"
|
||||
# logger.info('OPEN TUNNEL FROM %s to %s', prettySource, prettyDest)
|
||||
|
||||
# except Exception:
|
||||
# if consts.DEBUG:
|
||||
# logger.exception('COMMAND')
|
||||
# logger.error('ERROR from %s', prettySource)
|
||||
# await source.sendall(b'ERROR_COMMAND')
|
||||
# return
|
||||
|
||||
# # Communicate source OPEN is ok
|
||||
# await source.sendall(b'OK')
|
||||
|
||||
# # Initialize own stats counter
|
||||
# counter = stats.Stats(self.ns)
|
||||
|
||||
# # Open remote server connection
|
||||
# try:
|
||||
# destination = await curio.open_connection(
|
||||
# result['host'], int(result['port'])
|
||||
# )
|
||||
# async with curio.TaskGroup(wait=any) as grp:
|
||||
# await grp.spawn(
|
||||
# Proxy.doProxy, source, destination, counter.as_sent_counter()
|
||||
# )
|
||||
# await grp.spawn(
|
||||
# Proxy.doProxy, destination, source, counter.as_recv_counter()
|
||||
# )
|
||||
# logger.debug('PROXIES READY')
|
||||
|
||||
# logger.debug('Proxies finalized: %s', grp.exceptions)
|
||||
# await curio.run_in_thread(
|
||||
# Proxy.notifyEndToUds, self.cfg, result['notify'].encode(), counter
|
||||
# )
|
||||
|
||||
# except Exception as e:
|
||||
# if consts.DEBUG:
|
||||
# logger.exception('OPEN REMOTE')
|
||||
|
||||
# logger.error('REMOTE from %s: %s', address, e)
|
||||
# finally:
|
||||
# counter.close() # So we ensure stats are correctly updated on ns
|
||||
|
||||
# logger.info(
|
||||
# 'TERMINATED %s to %s, s:%s, r:%s, t:%s',
|
||||
# prettySource,
|
||||
# prettyDest,
|
||||
# counter.sent,
|
||||
# counter.recv,
|
||||
# int(counter.end - counter.start),
|
||||
# )
|
||||
|
274
tunnel-server/src/uds_tunnel/tunnel.py
Normal file
274
tunnel-server/src/uds_tunnel/tunnel.py
Normal file
@ -0,0 +1,274 @@
|
||||
import asyncio
|
||||
import typing
|
||||
import logging
|
||||
|
||||
|
||||
import requests
|
||||
|
||||
from . import consts
|
||||
from . import config
|
||||
from . import stats
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from . import proxy
|
||||
|
||||
|
||||
# Protocol
|
||||
class TunnelProtocol(asyncio.Protocol):
|
||||
# future to mark eof
|
||||
finished: asyncio.Future
|
||||
# Transport and other side of tunnel
|
||||
transport: 'asyncio.transports.Transport'
|
||||
other_side: 'TunnelProtocol'
|
||||
# Current state
|
||||
runner: typing.Any
|
||||
# Command buffer
|
||||
cmd: bytes
|
||||
# owner Proxy class
|
||||
owner: 'proxy.Proxy'
|
||||
# source of connection
|
||||
source: typing.Tuple[str, int]
|
||||
destination: typing.Tuple[str, int]
|
||||
# Counters & stats related
|
||||
stats_manager: stats.Stats
|
||||
# counter
|
||||
counter: stats.StatsSingleCounter
|
||||
|
||||
def __init__(self, owner: 'proxy.Proxy', other_side: typing.Optional['TunnelProtocol'] = None) -> None:
|
||||
# If no other side is given, we are the server part
|
||||
super().__init__()
|
||||
if other_side:
|
||||
self.other_side = other_side
|
||||
self.stats_manager = other_side.stats_manager
|
||||
#self.counter = self.stats_manager.as_recv_counter()
|
||||
self.runner = self.do_proxy
|
||||
else:
|
||||
self.other_side = self
|
||||
self.stats_manager = stats.Stats(owner.ns)
|
||||
#self.counter = self.stats_manager.as_sent_counter()
|
||||
self.runner = self.do_command
|
||||
|
||||
# transport is undefined until connection_made is called
|
||||
self.finished = asyncio.Future()
|
||||
self.cmd = b''
|
||||
self.owner = owner
|
||||
self.source = ('', 0)
|
||||
self.destination = ('', 0)
|
||||
|
||||
|
||||
def process_open(self):
|
||||
# Open Command has the ticket behind it
|
||||
if len(self.cmd) < consts.TICKET_LENGTH + consts.COMMAND_LENGTH:
|
||||
return # Wait for more data
|
||||
|
||||
# Ticket received, now process it with UDS
|
||||
ticket = self.cmd[consts.COMMAND_LENGTH:]
|
||||
|
||||
# clean up the command
|
||||
self.cmd = b''
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
async def open_other_side() -> None:
|
||||
try:
|
||||
result = await TunnelProtocol.getFromUds(self.owner.cfg, ticket, self.source)
|
||||
except Exception as e:
|
||||
logger.error('ERROR %s', e.args[0] if e.args else e)
|
||||
self.transport.write(b'ERROR_TICKET')
|
||||
self.transport.close() # And force close
|
||||
return
|
||||
|
||||
# store for future use
|
||||
self.destination = (result['host'], int(result['port']))
|
||||
|
||||
logger.info('OPEN TUNNEL FROM %s to %s', self.pretty_source(), self.pretty_destination())
|
||||
|
||||
try:
|
||||
(_, protocol) = await loop.create_connection(
|
||||
lambda: TunnelProtocol(self.owner, self), self.destination[0], self.destination[1]
|
||||
)
|
||||
self.other_side = typing.cast('TunnelProtocol', protocol)
|
||||
|
||||
# send OK to client
|
||||
self.transport.write(b'OK')
|
||||
except Exception as e:
|
||||
logger.error('Error opening connection: %s', e)
|
||||
self.close_connection()
|
||||
|
||||
loop.create_task(open_other_side())
|
||||
self.runner = self.do_proxy
|
||||
|
||||
def process_stats(self, full: bool) -> None:
|
||||
# if pasword is not already received, wait for it
|
||||
if len(self.cmd) < consts.PASSWORD_LENGTH + consts.COMMAND_LENGTH:
|
||||
return
|
||||
|
||||
try:
|
||||
logger.info('COMMAND: %s', self.cmd[:consts.COMMAND_LENGTH].decode())
|
||||
|
||||
# Check valid source ip
|
||||
if self.transport.get_extra_info('peername')[0] not in self.owner.cfg.allow:
|
||||
# Invalid source
|
||||
self.transport.write(b'FORBIDDEN')
|
||||
return
|
||||
|
||||
# Check password
|
||||
passwd = self.cmd[consts.COMMAND_LENGTH:]
|
||||
|
||||
# Clean up the command
|
||||
self.cmd = b''
|
||||
|
||||
if passwd.decode(errors='ignore') != self.owner.cfg.secret:
|
||||
# Invalid password
|
||||
self.transport.write(b'FORBIDDEN')
|
||||
return
|
||||
|
||||
data = stats.GlobalStats.get_stats(self.owner.ns)
|
||||
|
||||
for v in data:
|
||||
logger.debug('SENDING %s', v)
|
||||
self.transport.write(v.encode() + b'\n')
|
||||
|
||||
logger.info('TERMINATED %s', self.pretty_source())
|
||||
finally:
|
||||
self.close_connection()
|
||||
|
||||
def do_command(self, data: bytes) -> None:
|
||||
self.cmd += data
|
||||
if len(self.cmd) >= consts.COMMAND_LENGTH:
|
||||
logger.info('CONNECT FROM %s', self.pretty_source())
|
||||
|
||||
command = self.cmd[:consts.COMMAND_LENGTH]
|
||||
try:
|
||||
if command == consts.COMMAND_OPEN:
|
||||
self.process_open()
|
||||
elif command == consts.COMMAND_TEST:
|
||||
logger.info('COMMAND: TEST')
|
||||
self.transport.write(b'OK')
|
||||
|
||||
self.close_connection()
|
||||
return
|
||||
elif command in (consts.COMMAND_STAT, consts.COMMAND_INFO):
|
||||
# This is an stats requests
|
||||
self.process_stats(full=command == consts.COMMAND_STAT)
|
||||
return
|
||||
else:
|
||||
raise Exception('Invalid command')
|
||||
except Exception:
|
||||
if consts.DEBUG:
|
||||
logger.exception('COMMAND')
|
||||
logger.error('ERROR from %s', self.pretty_source())
|
||||
self.transport.write(b'ERROR_COMMAND')
|
||||
self.transport.close() # end the connection
|
||||
return
|
||||
# if not enough data, wait for more...
|
||||
|
||||
def do_proxy(self, data: bytes) -> None:
|
||||
self.counter.add(len(data))
|
||||
logger.debug('Processing proxy: %s', len(data))
|
||||
self.other_side.transport.write(data)
|
||||
|
||||
|
||||
# inherited from asyncio.Protocol
|
||||
|
||||
def connection_made(self, transport: 'asyncio.transports.BaseTransport') -> None:
|
||||
logger.debug('Connection made: %s', transport.get_extra_info('peername'))
|
||||
|
||||
# We know for sure that the transport is a Transport.
|
||||
self.transport = typing.cast('asyncio.transports.Transport', transport)
|
||||
self.cmd = b''
|
||||
self.source = self.transport.get_extra_info('peername')
|
||||
|
||||
def data_received(self, data: bytes):
|
||||
logger.debug('Data received: %s', len(data))
|
||||
self.runner(data)
|
||||
|
||||
def connection_lost(self, exc: typing.Optional[Exception]) -> None:
|
||||
logger.debug('Connection closed : %s', exc)
|
||||
self.finished.set_result(True)
|
||||
if self.other_side is not self:
|
||||
self.other_side.transport.close()
|
||||
else:
|
||||
self.stats_manager.close()
|
||||
|
||||
# helpers
|
||||
# source address, pretty format
|
||||
def pretty_source(self) -> str:
|
||||
return self.source[0] + ':' + str(self.source[1])
|
||||
|
||||
def pretty_destination(self) -> str:
|
||||
return self.destination[0] + ':' + str(self.destination[1])
|
||||
|
||||
def close_connection(self):
|
||||
self.transport.close()
|
||||
# If destination is not set, it's a command processing (i.e. TEST or STAT)
|
||||
if self.destination[0] != '':
|
||||
logger.info(
|
||||
'TERMINATED %s to %s, s:%s, r:%s, t:%s',
|
||||
self.pretty_source(),
|
||||
self.pretty_destination(),
|
||||
self.stats_manager.sent,
|
||||
self.stats_manager.recv,
|
||||
int(self.stats_manager.end - self.stats_manager.start),
|
||||
)
|
||||
else:
|
||||
logger.info('TERMINATED %s', self.pretty_source())
|
||||
|
||||
@staticmethod
|
||||
def _getUdsUrl(
|
||||
cfg: config.ConfigurationType,
|
||||
ticket: bytes,
|
||||
msg: str,
|
||||
queryParams: typing.Mapping[str, str] = None,
|
||||
) -> typing.MutableMapping[str, typing.Any]:
|
||||
try:
|
||||
url = (
|
||||
cfg.uds_server + '/' + ticket.decode() + '/' + msg + '/' + cfg.uds_token
|
||||
)
|
||||
if queryParams:
|
||||
url += '?' + '&'.join(
|
||||
[f'{key}={value}' for key, value in queryParams.items()]
|
||||
)
|
||||
r = requests.get(
|
||||
url,
|
||||
headers={
|
||||
'content-type': 'application/json',
|
||||
'User-Agent': f'UDSTunnel-{consts.VERSION}',
|
||||
},
|
||||
)
|
||||
if not r.ok:
|
||||
raise Exception(r.content or 'Invalid Ticket (timed out)')
|
||||
|
||||
return r.json()
|
||||
except Exception as e:
|
||||
raise Exception(f'TICKET COMMS ERROR: {ticket.decode()} {msg} {e!s}')
|
||||
|
||||
@staticmethod
|
||||
async def getFromUds(
|
||||
cfg: config.ConfigurationType, ticket: bytes, address: typing.Tuple[str, int]
|
||||
) -> typing.MutableMapping[str, typing.Any]:
|
||||
# Sanity checks
|
||||
if len(ticket) != consts.TICKET_LENGTH:
|
||||
raise Exception(f'TICKET INVALID (len={len(ticket)})')
|
||||
|
||||
for n, i in enumerate(ticket.decode(errors='ignore')):
|
||||
if (
|
||||
(i >= 'a' and i <= 'z')
|
||||
or (i >= '0' and i <= '9')
|
||||
or (i >= 'A' and i <= 'Z')
|
||||
):
|
||||
continue # Correctus
|
||||
raise Exception(f'TICKET INVALID (char {i} at pos {n})')
|
||||
|
||||
return await asyncio.get_event_loop().run_in_executor(None, TunnelProtocol._getUdsUrl, cfg, ticket, address[0])
|
||||
|
||||
@staticmethod
|
||||
async def notifyEndToUds(
|
||||
cfg: config.ConfigurationType, ticket: bytes, counter: stats.Stats
|
||||
) -> None:
|
||||
TunnelProtocol._getUdsUrl(
|
||||
cfg, ticket, 'stop', {'sent': str(counter.sent), 'recv': str(counter.recv)}
|
||||
) # Ignore results
|
||||
|
@ -22,7 +22,7 @@ lognumber = 3
|
||||
address = 0.0.0.0
|
||||
|
||||
# Number of workers. Defaults to 0 (means "as much as cores")
|
||||
workers = 2
|
||||
workers = 1
|
||||
|
||||
# Listening port
|
||||
port = 7777
|
||||
|
@ -98,20 +98,13 @@ async def tunnel_proc_async(
|
||||
) -> None:
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
# Create event for flagging when we have new data
|
||||
event = asyncio.Event()
|
||||
loop.add_reader(pipe.fileno(), event.set)
|
||||
|
||||
tasks: typing.List[asyncio.Task] = []
|
||||
|
||||
async def get_socket() -> typing.Tuple[
|
||||
typing.Optional[socket.socket], typing.Tuple[str, int]
|
||||
]:
|
||||
def get_socket() -> typing.Optional[socket.socket]:
|
||||
try:
|
||||
while True:
|
||||
await event.wait()
|
||||
# Clear back event, for next data
|
||||
event.clear()
|
||||
msg: typing.Optional[
|
||||
typing.Tuple[socket.socket, typing.Tuple[str, int]]
|
||||
] = pipe.recv()
|
||||
@ -121,9 +114,7 @@ async def tunnel_proc_async(
|
||||
|
||||
try:
|
||||
# First, ensure handshake (simple handshake) and command
|
||||
data: bytes = await loop.sock_recv(
|
||||
source, len(consts.HANDSHAKE_V1)
|
||||
)
|
||||
data: bytes = source.recv(len(consts.HANDSHAKE_V1))
|
||||
|
||||
if data != consts.HANDSHAKE_V1:
|
||||
raise Exception() # Invalid handshake
|
||||
@ -135,12 +126,10 @@ async def tunnel_proc_async(
|
||||
source.close()
|
||||
continue
|
||||
|
||||
return msg
|
||||
|
||||
# Process other messages, and retry
|
||||
return source
|
||||
except Exception:
|
||||
logger.exception('Receiving data from parent process')
|
||||
return None, ('', 0)
|
||||
return None
|
||||
|
||||
async def run_server() -> None:
|
||||
# Instantiate a proxy redirector for this process (we only need one per process!!)
|
||||
@ -159,11 +148,11 @@ async def tunnel_proc_async(
|
||||
while True:
|
||||
address: typing.Tuple[str, int] = ('', 0)
|
||||
try:
|
||||
sock, address = await get_socket()
|
||||
sock = await loop.run_in_executor(None, get_socket)
|
||||
if not sock:
|
||||
break # No more sockets, exit
|
||||
logger.debug(f'CONNECTION from {address!r} (pid: {os.getpid()})')
|
||||
tasks.append(asyncio.create_task(tunneler(sock, address, context)))
|
||||
tasks.append(asyncio.create_task(tunneler(sock, context)))
|
||||
except Exception:
|
||||
logger.error('NEGOTIATION ERROR from %s', address[0])
|
||||
|
||||
@ -176,9 +165,6 @@ async def tunnel_proc_async(
|
||||
# Remove finished tasks from list
|
||||
del tasks[:tasks_number]
|
||||
|
||||
# Remove reader from event loop
|
||||
loop.remove_reader(pipe.fileno())
|
||||
|
||||
|
||||
def tunnel_main():
|
||||
cfg = config.read()
|
||||
|
Loading…
x
Reference in New Issue
Block a user