Removing curio from tunneler, so we do not have an unneeded dependency

This commit is contained in:
Adolfo Gómez García 2022-01-26 14:32:41 +01:00
parent 7d8ae689b5
commit 77b6eff8e4
6 changed files with 389 additions and 215 deletions

View File

@ -58,3 +58,4 @@ COMMAND_OPEN = b'OPEN'
COMMAND_TEST = b'TEST'
COMMAND_STAT = b'STAT' # full stats
COMMAND_INFO = b'INFO' # Basic stats, currently same as FULL

View File

@ -39,8 +39,8 @@ class Processes:
def add_child_pid(self):
own_conn, child_conn = multiprocessing.Pipe()
task = multiprocessing.Process(
target=asyncio.run,
args=(self.process(child_conn, self.cfg, self.ns),)
target=Processes.runner,
args=(self.process, child_conn, self.cfg, self.ns),
)
task.start()
logger.debug('ADD CHILD PID: %s', task.pid)
@ -98,4 +98,7 @@ class Processes:
i[2].kill()
except Exception as e:
logger.info('KILLING child %s: %s', i[2], e)
@staticmethod
def runner(proc: ProcessType, conn: 'Connection', cfg: config.ConfigurationType, ns: 'Namespace') -> None:
asyncio.run(proc(conn, cfg, ns))

View File

@ -33,11 +33,8 @@ import socket
import logging
import typing
import requests
from . import config
from . import stats
from . import consts
from . import tunnel
if typing.TYPE_CHECKING:
from multiprocessing.managers import Namespace
@ -54,196 +51,109 @@ class Proxy:
self.cfg = cfg
self.ns = ns
@staticmethod
def _getUdsUrl(
cfg: config.ConfigurationType,
ticket: bytes,
msg: str,
queryParams: typing.Mapping[str, str] = None,
) -> typing.MutableMapping[str, typing.Any]:
try:
url = (
cfg.uds_server + '/' + ticket.decode() + '/' + msg + '/' + cfg.uds_token
)
if queryParams:
url += '?' + '&'.join(
[f'{key}={value}' for key, value in queryParams.items()]
)
r = requests.get(
url,
headers={
'content-type': 'application/json',
'User-Agent': f'UDSTunnel-{consts.VERSION}',
},
)
if not r.ok:
raise Exception(r.content or 'Invalid Ticket (timed out)')
return r.json()
except Exception as e:
raise Exception(f'TICKET COMMS ERROR: {ticket.decode()} {msg} {e!s}')
@staticmethod
def getFromUds(
cfg: config.ConfigurationType, ticket: bytes, address: typing.Tuple[str, int]
) -> typing.MutableMapping[str, typing.Any]:
# Sanity checks
if len(ticket) != consts.TICKET_LENGTH:
raise Exception(f'TICKET INVALID (len={len(ticket)})')
for n, i in enumerate(ticket.decode(errors='ignore')):
if (
(i >= 'a' and i <= 'z')
or (i >= '0' and i <= '9')
or (i >= 'A' and i <= 'Z')
):
continue # Correctus
raise Exception(f'TICKET INVALID (char {i} at pos {n})')
return Proxy._getUdsUrl(cfg, ticket, address[0])
@staticmethod
def notifyEndToUds(
cfg: config.ConfigurationType, ticket: bytes, counter: stats.Stats
) -> None:
Proxy._getUdsUrl(
cfg, ticket, 'stop', {'sent': str(counter.sent), 'recv': str(counter.recv)}
) # Ignore results
# @staticmethod
# async def doProxy(
# source: 'curio.io.Socket',
# destination: 'curio.io.Socket',
# counter: stats.StatsSingleCounter,
# ) -> None:
# try:
# while True:
# data = await source.recv(consts.BUFFER_SIZE)
# if not data:
# break
# await destination.sendall(data)
# counter.add(len(data))
# except Exception:
# # Connection broken, same result as closed for us
# # We must notice that i'ts easy that when closing one part of the tunnel,
# # the other can break (due to some internal data), that's why even log is removed
# # logger.info('CONNECTION LOST FROM %s to %s', source.getsockname(), destination.getpeername())
# pass
# Method responsible of proxying requests
async def __call__(self, source: socket.socket, address: typing.Tuple[str, int], context: 'ssl.SSLContext') -> None:
await self.proxy(source, address, context)
async def __call__(self, source: socket.socket, context: 'ssl.SSLContext') -> None:
await self.proxy(source, context)
async def stats(self, full: bool, source, address: typing.Tuple[str, int]) -> None:
# Check valid source ip
if address[0] not in self.cfg.allow:
# Invalid source
await source.sendall(b'FORBIDDEN')
return
# Check password
passwd = await source.recv(consts.PASSWORD_LENGTH)
if passwd.decode(errors='ignore') != self.cfg.secret:
# Invalid password
await source.sendall(b'FORBIDDEN')
return
data = stats.GlobalStats.get_stats(self.ns)
for v in data:
logger.debug('SENDING %s', v)
await source.sendall(v.encode() + b'\n')
async def proxy(self, source: socket.socket, address: typing.Tuple[str, int], context: 'ssl.SSLContext') -> None:
prettySource = address[0] # Get only source IP
prettyDest = ''
logger.info('CONNECT FROM %s', prettySource)
async def proxy(self, source: socket.socket, context: 'ssl.SSLContext') -> None:
loop = asyncio.get_event_loop()
# Handshake correct in this point, start SSL connection
try:
command: bytes = await loop.sock_recv(source, consts.COMMAND_LENGTH)
if command == consts.COMMAND_TEST:
logger.info('COMMAND: TEST')
await loop.sock_sendall(source, b'OK')
logger.info('TERMINATED %s', prettySource)
return
if command in (consts.COMMAND_STAT, consts.COMMAND_INFO):
logger.info('COMMAND: %s', command.decode())
# This is an stats requests
await self.stats(
full=command == consts.COMMAND_STAT, source=source, address=address
)
logger.info('TERMINATED %s', prettySource)
return
if command != consts.COMMAND_OPEN:
# Invalid command
raise Exception()
# Now, read a TICKET_LENGTH (64) bytes string, that must be [a-zA-Z0-9]{64}
ticket: bytes = await source.recv(consts.TICKET_LENGTH)
# Ticket received, now process it with UDS
try:
result = await curio.run_in_thread(
Proxy.getFromUds, self.cfg, ticket, address
)
except Exception as e:
logger.error('ERROR %s', e.args[0] if e.args else e)
await source.sendall(b'ERROR_TICKET')
return
prettyDest = f"{result['host']}:{result['port']}"
logger.info('OPEN TUNNEL FROM %s to %s', prettySource, prettyDest)
except Exception:
if consts.DEBUG:
logger.exception('COMMAND')
logger.error('ERROR from %s', prettySource)
await source.sendall(b'ERROR_COMMAND')
return
# Communicate source OPEN is ok
await source.sendall(b'OK')
# Initialize own stats counter
counter = stats.Stats(self.ns)
# Open remote server connection
try:
destination = await curio.open_connection(
result['host'], int(result['port'])
)
async with curio.TaskGroup(wait=any) as grp:
await grp.spawn(
Proxy.doProxy, source, destination, counter.as_sent_counter()
)
await grp.spawn(
Proxy.doProxy, destination, source, counter.as_recv_counter()
)
logger.debug('PROXIES READY')
logger.debug('Proxies finalized: %s', grp.exceptions)
await curio.run_in_thread(
Proxy.notifyEndToUds, self.cfg, result['notify'].encode(), counter
)
except Exception as e:
if consts.DEBUG:
logger.exception('OPEN REMOTE')
logger.error('REMOTE from %s: %s', address, e)
finally:
counter.close() # So we ensure stats are correctly updated on ns
logger.info(
'TERMINATED %s to %s, s:%s, r:%s, t:%s',
prettySource,
prettyDest,
counter.sent,
counter.recv,
int(counter.end - counter.start),
# Handshake correct in this point, upgrade the connection to TSL and let
# the protocol controller do the rest
# Upgrade connection to SSL, and use asyncio to handle the rest
transport: 'asyncio.transports.Transport'
protocol: tunnel.TunnelProtocol
(transport, protocol) = await loop.connect_accepted_socket( # type: ignore
lambda: tunnel.TunnelProtocol(self), source, ssl=context
)
await protocol.finished
return
# try:
# command: bytes = await loop.sock_recv(source, consts.COMMAND_LENGTH)
# if command == consts.COMMAND_TEST:
# logger.info('COMMAND: TEST')
# await loop.sock_sendall(source, b'OK')
# logger.info('TERMINATED %s', prettySource)
# return
# if command in (consts.COMMAND_STAT, consts.COMMAND_INFO):
# logger.info('COMMAND: %s', command.decode())
# # This is an stats requests
# await self.stats(
# full=command == consts.COMMAND_STAT, source=source, address=address
# )
# logger.info('TERMINATED %s', prettySource)
# return
# if command != consts.COMMAND_OPEN:
# # Invalid command
# raise Exception()
# # Now, read a TICKET_LENGTH (64) bytes string, that must be [a-zA-Z0-9]{64}
# ticket: bytes = await source.recv(consts.TICKET_LENGTH)
# # Ticket received, now process it with UDS
# try:
# result = await curio.run_in_thread(
# Proxy.getFromUds, self.cfg, ticket, address
# )
# except Exception as e:
# logger.error('ERROR %s', e.args[0] if e.args else e)
# await source.sendall(b'ERROR_TICKET')
# return
# prettyDest = f"{result['host']}:{result['port']}"
# logger.info('OPEN TUNNEL FROM %s to %s', prettySource, prettyDest)
# except Exception:
# if consts.DEBUG:
# logger.exception('COMMAND')
# logger.error('ERROR from %s', prettySource)
# await source.sendall(b'ERROR_COMMAND')
# return
# # Communicate source OPEN is ok
# await source.sendall(b'OK')
# # Initialize own stats counter
# counter = stats.Stats(self.ns)
# # Open remote server connection
# try:
# destination = await curio.open_connection(
# result['host'], int(result['port'])
# )
# async with curio.TaskGroup(wait=any) as grp:
# await grp.spawn(
# Proxy.doProxy, source, destination, counter.as_sent_counter()
# )
# await grp.spawn(
# Proxy.doProxy, destination, source, counter.as_recv_counter()
# )
# logger.debug('PROXIES READY')
# logger.debug('Proxies finalized: %s', grp.exceptions)
# await curio.run_in_thread(
# Proxy.notifyEndToUds, self.cfg, result['notify'].encode(), counter
# )
# except Exception as e:
# if consts.DEBUG:
# logger.exception('OPEN REMOTE')
# logger.error('REMOTE from %s: %s', address, e)
# finally:
# counter.close() # So we ensure stats are correctly updated on ns
# logger.info(
# 'TERMINATED %s to %s, s:%s, r:%s, t:%s',
# prettySource,
# prettyDest,
# counter.sent,
# counter.recv,
# int(counter.end - counter.start),
# )

View File

@ -0,0 +1,274 @@
import asyncio
import typing
import logging
import requests
from . import consts
from . import config
from . import stats
logger = logging.getLogger(__name__)
if typing.TYPE_CHECKING:
from . import proxy
# Protocol
class TunnelProtocol(asyncio.Protocol):
# future to mark eof
finished: asyncio.Future
# Transport and other side of tunnel
transport: 'asyncio.transports.Transport'
other_side: 'TunnelProtocol'
# Current state
runner: typing.Any
# Command buffer
cmd: bytes
# owner Proxy class
owner: 'proxy.Proxy'
# source of connection
source: typing.Tuple[str, int]
destination: typing.Tuple[str, int]
# Counters & stats related
stats_manager: stats.Stats
# counter
counter: stats.StatsSingleCounter
def __init__(self, owner: 'proxy.Proxy', other_side: typing.Optional['TunnelProtocol'] = None) -> None:
# If no other side is given, we are the server part
super().__init__()
if other_side:
self.other_side = other_side
self.stats_manager = other_side.stats_manager
#self.counter = self.stats_manager.as_recv_counter()
self.runner = self.do_proxy
else:
self.other_side = self
self.stats_manager = stats.Stats(owner.ns)
#self.counter = self.stats_manager.as_sent_counter()
self.runner = self.do_command
# transport is undefined until connection_made is called
self.finished = asyncio.Future()
self.cmd = b''
self.owner = owner
self.source = ('', 0)
self.destination = ('', 0)
def process_open(self):
# Open Command has the ticket behind it
if len(self.cmd) < consts.TICKET_LENGTH + consts.COMMAND_LENGTH:
return # Wait for more data
# Ticket received, now process it with UDS
ticket = self.cmd[consts.COMMAND_LENGTH:]
# clean up the command
self.cmd = b''
loop = asyncio.get_event_loop()
async def open_other_side() -> None:
try:
result = await TunnelProtocol.getFromUds(self.owner.cfg, ticket, self.source)
except Exception as e:
logger.error('ERROR %s', e.args[0] if e.args else e)
self.transport.write(b'ERROR_TICKET')
self.transport.close() # And force close
return
# store for future use
self.destination = (result['host'], int(result['port']))
logger.info('OPEN TUNNEL FROM %s to %s', self.pretty_source(), self.pretty_destination())
try:
(_, protocol) = await loop.create_connection(
lambda: TunnelProtocol(self.owner, self), self.destination[0], self.destination[1]
)
self.other_side = typing.cast('TunnelProtocol', protocol)
# send OK to client
self.transport.write(b'OK')
except Exception as e:
logger.error('Error opening connection: %s', e)
self.close_connection()
loop.create_task(open_other_side())
self.runner = self.do_proxy
def process_stats(self, full: bool) -> None:
# if pasword is not already received, wait for it
if len(self.cmd) < consts.PASSWORD_LENGTH + consts.COMMAND_LENGTH:
return
try:
logger.info('COMMAND: %s', self.cmd[:consts.COMMAND_LENGTH].decode())
# Check valid source ip
if self.transport.get_extra_info('peername')[0] not in self.owner.cfg.allow:
# Invalid source
self.transport.write(b'FORBIDDEN')
return
# Check password
passwd = self.cmd[consts.COMMAND_LENGTH:]
# Clean up the command
self.cmd = b''
if passwd.decode(errors='ignore') != self.owner.cfg.secret:
# Invalid password
self.transport.write(b'FORBIDDEN')
return
data = stats.GlobalStats.get_stats(self.owner.ns)
for v in data:
logger.debug('SENDING %s', v)
self.transport.write(v.encode() + b'\n')
logger.info('TERMINATED %s', self.pretty_source())
finally:
self.close_connection()
def do_command(self, data: bytes) -> None:
self.cmd += data
if len(self.cmd) >= consts.COMMAND_LENGTH:
logger.info('CONNECT FROM %s', self.pretty_source())
command = self.cmd[:consts.COMMAND_LENGTH]
try:
if command == consts.COMMAND_OPEN:
self.process_open()
elif command == consts.COMMAND_TEST:
logger.info('COMMAND: TEST')
self.transport.write(b'OK')
self.close_connection()
return
elif command in (consts.COMMAND_STAT, consts.COMMAND_INFO):
# This is an stats requests
self.process_stats(full=command == consts.COMMAND_STAT)
return
else:
raise Exception('Invalid command')
except Exception:
if consts.DEBUG:
logger.exception('COMMAND')
logger.error('ERROR from %s', self.pretty_source())
self.transport.write(b'ERROR_COMMAND')
self.transport.close() # end the connection
return
# if not enough data, wait for more...
def do_proxy(self, data: bytes) -> None:
self.counter.add(len(data))
logger.debug('Processing proxy: %s', len(data))
self.other_side.transport.write(data)
# inherited from asyncio.Protocol
def connection_made(self, transport: 'asyncio.transports.BaseTransport') -> None:
logger.debug('Connection made: %s', transport.get_extra_info('peername'))
# We know for sure that the transport is a Transport.
self.transport = typing.cast('asyncio.transports.Transport', transport)
self.cmd = b''
self.source = self.transport.get_extra_info('peername')
def data_received(self, data: bytes):
logger.debug('Data received: %s', len(data))
self.runner(data)
def connection_lost(self, exc: typing.Optional[Exception]) -> None:
logger.debug('Connection closed : %s', exc)
self.finished.set_result(True)
if self.other_side is not self:
self.other_side.transport.close()
else:
self.stats_manager.close()
# helpers
# source address, pretty format
def pretty_source(self) -> str:
return self.source[0] + ':' + str(self.source[1])
def pretty_destination(self) -> str:
return self.destination[0] + ':' + str(self.destination[1])
def close_connection(self):
self.transport.close()
# If destination is not set, it's a command processing (i.e. TEST or STAT)
if self.destination[0] != '':
logger.info(
'TERMINATED %s to %s, s:%s, r:%s, t:%s',
self.pretty_source(),
self.pretty_destination(),
self.stats_manager.sent,
self.stats_manager.recv,
int(self.stats_manager.end - self.stats_manager.start),
)
else:
logger.info('TERMINATED %s', self.pretty_source())
@staticmethod
def _getUdsUrl(
cfg: config.ConfigurationType,
ticket: bytes,
msg: str,
queryParams: typing.Mapping[str, str] = None,
) -> typing.MutableMapping[str, typing.Any]:
try:
url = (
cfg.uds_server + '/' + ticket.decode() + '/' + msg + '/' + cfg.uds_token
)
if queryParams:
url += '?' + '&'.join(
[f'{key}={value}' for key, value in queryParams.items()]
)
r = requests.get(
url,
headers={
'content-type': 'application/json',
'User-Agent': f'UDSTunnel-{consts.VERSION}',
},
)
if not r.ok:
raise Exception(r.content or 'Invalid Ticket (timed out)')
return r.json()
except Exception as e:
raise Exception(f'TICKET COMMS ERROR: {ticket.decode()} {msg} {e!s}')
@staticmethod
async def getFromUds(
cfg: config.ConfigurationType, ticket: bytes, address: typing.Tuple[str, int]
) -> typing.MutableMapping[str, typing.Any]:
# Sanity checks
if len(ticket) != consts.TICKET_LENGTH:
raise Exception(f'TICKET INVALID (len={len(ticket)})')
for n, i in enumerate(ticket.decode(errors='ignore')):
if (
(i >= 'a' and i <= 'z')
or (i >= '0' and i <= '9')
or (i >= 'A' and i <= 'Z')
):
continue # Correctus
raise Exception(f'TICKET INVALID (char {i} at pos {n})')
return await asyncio.get_event_loop().run_in_executor(None, TunnelProtocol._getUdsUrl, cfg, ticket, address[0])
@staticmethod
async def notifyEndToUds(
cfg: config.ConfigurationType, ticket: bytes, counter: stats.Stats
) -> None:
TunnelProtocol._getUdsUrl(
cfg, ticket, 'stop', {'sent': str(counter.sent), 'recv': str(counter.recv)}
) # Ignore results

View File

@ -22,7 +22,7 @@ lognumber = 3
address = 0.0.0.0
# Number of workers. Defaults to 0 (means "as much as cores")
workers = 2
workers = 1
# Listening port
port = 7777

View File

@ -98,20 +98,13 @@ async def tunnel_proc_async(
) -> None:
loop = asyncio.get_event_loop()
# Create event for flagging when we have new data
event = asyncio.Event()
loop.add_reader(pipe.fileno(), event.set)
tasks: typing.List[asyncio.Task] = []
async def get_socket() -> typing.Tuple[
typing.Optional[socket.socket], typing.Tuple[str, int]
]:
def get_socket() -> typing.Optional[socket.socket]:
try:
while True:
await event.wait()
# Clear back event, for next data
event.clear()
msg: typing.Optional[
typing.Tuple[socket.socket, typing.Tuple[str, int]]
] = pipe.recv()
@ -121,9 +114,7 @@ async def tunnel_proc_async(
try:
# First, ensure handshake (simple handshake) and command
data: bytes = await loop.sock_recv(
source, len(consts.HANDSHAKE_V1)
)
data: bytes = source.recv(len(consts.HANDSHAKE_V1))
if data != consts.HANDSHAKE_V1:
raise Exception() # Invalid handshake
@ -135,12 +126,10 @@ async def tunnel_proc_async(
source.close()
continue
return msg
# Process other messages, and retry
return source
except Exception:
logger.exception('Receiving data from parent process')
return None, ('', 0)
return None
async def run_server() -> None:
# Instantiate a proxy redirector for this process (we only need one per process!!)
@ -159,11 +148,11 @@ async def tunnel_proc_async(
while True:
address: typing.Tuple[str, int] = ('', 0)
try:
sock, address = await get_socket()
sock = await loop.run_in_executor(None, get_socket)
if not sock:
break # No more sockets, exit
logger.debug(f'CONNECTION from {address!r} (pid: {os.getpid()})')
tasks.append(asyncio.create_task(tunneler(sock, address, context)))
tasks.append(asyncio.create_task(tunneler(sock, context)))
except Exception:
logger.error('NEGOTIATION ERROR from %s', address[0])
@ -176,9 +165,6 @@ async def tunnel_proc_async(
# Remove finished tasks from list
del tasks[:tasks_number]
# Remove reader from event loop
loop.remove_reader(pipe.fileno())
def tunnel_main():
cfg = config.read()