forked from shaba/openuds
new tunnel server ready for testing phase
This commit is contained in:
parent
971e5984d9
commit
d6a8639b18
@ -39,8 +39,12 @@ from .consts import CONFIGFILE
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ConfigurationType(typing.NamedTuple):
|
||||
pidfile: str
|
||||
|
||||
log_level: str
|
||||
log_file: str
|
||||
log_size: int
|
||||
log_number: int
|
||||
|
||||
listen_address: str
|
||||
listen_port: int
|
||||
@ -73,10 +77,18 @@ def read() -> ConfigurationType:
|
||||
h.update(uds.get('secret', '').encode())
|
||||
secret = h.hexdigest()
|
||||
|
||||
|
||||
try:
|
||||
# log size
|
||||
logsize: str = uds.get('logsize', '32M')
|
||||
if logsize[-1] == 'M':
|
||||
logsize = logsize[:-1]
|
||||
return ConfigurationType(
|
||||
pidfile=uds.get('pidfile', '/dev/null'),
|
||||
log_level=uds.get('loglevel', 'ERROR'),
|
||||
log_file=uds.get('logfile', ''),
|
||||
log_size=int(logsize)*1024*1024,
|
||||
log_number=int(uds.get('lognumber', '3')),
|
||||
listen_address=uds.get('address', '0.0.0.0'),
|
||||
listen_port=int(uds.get('port', '443')),
|
||||
workers=int(uds.get('workers', '0')) or multiprocessing.cpu_count(),
|
||||
|
@ -43,6 +43,7 @@ if typing.TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Proxy:
|
||||
cfg: config.ConfigurationType
|
||||
ns: 'Namespace'
|
||||
@ -52,13 +53,19 @@ class Proxy:
|
||||
self.ns = ns
|
||||
|
||||
@staticmethod
|
||||
def getFromUds(cfg: config.ConfigurationType, ticket: bytes) -> typing.MutableMapping[str, typing.Any]:
|
||||
def getFromUds(
|
||||
cfg: config.ConfigurationType, ticket: bytes
|
||||
) -> typing.MutableMapping[str, typing.Any]:
|
||||
# Sanity checks
|
||||
if len(ticket) != consts.TICKET_LENGTH:
|
||||
raise Exception(f'TICKET INVALID (len={len(ticket)})')
|
||||
|
||||
for n, i in enumerate(ticket.decode(errors='ignore')):
|
||||
if (i >= 'a' and i <= 'z') or (i >= '0' and i <= '9') or (i >= 'A' and i <= 'Z'):
|
||||
if (
|
||||
(i >= 'a' and i <= 'z')
|
||||
or (i >= '0' and i <= '9')
|
||||
or (i >= 'A' and i <= 'Z')
|
||||
):
|
||||
continue # Correctus
|
||||
raise Exception(f'TICKET INVALID (char {i} at pos {n})')
|
||||
|
||||
@ -68,7 +75,7 @@ class Proxy:
|
||||
# raise Exception(f'TICKET INVALID (check {r.json})')
|
||||
return {
|
||||
'host': ['172.27.1.15', '172.27.0.10'][int(ticket[0]) - 0x30],
|
||||
'port': '3389'
|
||||
'port': '3389',
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
@ -98,17 +105,16 @@ class Proxy:
|
||||
await source.sendall(b'FORBIDDEN')
|
||||
return
|
||||
|
||||
logger.info('STATS TO %s', address)
|
||||
|
||||
data = stats.GlobalStats.get_stats(self.ns)
|
||||
|
||||
for v in data:
|
||||
logger.debug('SENDING %s', v)
|
||||
await source.sendall(v.encode() + b'\n')
|
||||
|
||||
|
||||
async def proxy(self, source, address: typing.Tuple[str, int]) -> None:
|
||||
logger.info('OPEN FROM %s', address)
|
||||
pretty_adress = address[0] # Get only source IP
|
||||
|
||||
logger.info('CONNECT FROM %s', pretty_adress)
|
||||
|
||||
try:
|
||||
# First, ensure handshake (simple handshake) and command
|
||||
@ -128,12 +134,16 @@ class Proxy:
|
||||
# Handshake correct, get the command (4 bytes)
|
||||
command: bytes = await source.recv(consts.COMMAND_LENGTH)
|
||||
if command == consts.COMMAND_TEST:
|
||||
logger.info('COMMAND: TEST')
|
||||
await source.sendall(b'OK')
|
||||
return
|
||||
|
||||
if command in (consts.COMMAND_STAT, consts.COMMAND_INFO):
|
||||
logger.info('COMMAND: %s', command.decode())
|
||||
# This is an stats requests
|
||||
await self.stats(full=command==consts.COMMAND_STAT, source=source, address=address)
|
||||
await self.stats(
|
||||
full=command == consts.COMMAND_STAT, source=source, address=address
|
||||
)
|
||||
return
|
||||
|
||||
if command != consts.COMMAND_OPEN:
|
||||
@ -147,19 +157,19 @@ class Proxy:
|
||||
try:
|
||||
result = await curio.run_in_thread(Proxy.getFromUds, self.cfg, ticket)
|
||||
except Exception as e:
|
||||
logger.error('%s', e.args[0] if e.args else e)
|
||||
logger.error('ERROR %s', e.args[0] if e.args else e)
|
||||
raise
|
||||
|
||||
print(f'Result: {result}')
|
||||
|
||||
# Invalid result from UDS, not allowed to connect
|
||||
if not result:
|
||||
raise Exception()
|
||||
raise Exception('INVALID TICKET')
|
||||
|
||||
logger.info('OPEN TUNNEL FROM %s to %s:%s', pretty_adress, result['host'], result['port'])
|
||||
|
||||
except Exception:
|
||||
if consts.DEBUG:
|
||||
logger.exception('COMMAND')
|
||||
logger.error('COMMAND from %s', address)
|
||||
logger.error('ERROR from %s', address)
|
||||
await source.sendall(b'COMMAND_ERROR')
|
||||
return
|
||||
|
||||
@ -171,11 +181,17 @@ class Proxy:
|
||||
|
||||
# Open remote server connection
|
||||
try:
|
||||
destination = await curio.open_connection(result['host'], int(result['port']))
|
||||
destination = await curio.open_connection(
|
||||
result['host'], int(result['port'])
|
||||
)
|
||||
async with curio.TaskGroup(wait=any) as grp:
|
||||
await grp.spawn(Proxy.doProxy, source, destination, counter.as_sent_counter())
|
||||
await grp.spawn(Proxy.doProxy, destination, source, counter.as_recv_counter())
|
||||
logger.debug('Launched proxies')
|
||||
await grp.spawn(
|
||||
Proxy.doProxy, source, destination, counter.as_sent_counter()
|
||||
)
|
||||
await grp.spawn(
|
||||
Proxy.doProxy, destination, source, counter.as_recv_counter()
|
||||
)
|
||||
logger.debug('PROXIES READY')
|
||||
|
||||
logger.debug('Proxies finalized: %s', grp.exceptions)
|
||||
|
||||
@ -185,7 +201,6 @@ class Proxy:
|
||||
|
||||
logger.error('REMOTE from %s: %s', address, e)
|
||||
finally:
|
||||
counter.close()
|
||||
counter.close() # So we ensure stats are correctly updated on ns
|
||||
|
||||
|
||||
logger.info('CLOSED FROM %s', address)
|
||||
logger.info('TERMINATED %s', ':'.join(str(i) for i in address))
|
||||
|
@ -1,7 +1,16 @@
|
||||
# Sample testing UDS tunnel configuration
|
||||
# Sample DS tunnel configuration
|
||||
|
||||
# Log level, valid are DEBUG, INFO, WARN, ERROR
|
||||
pidfile = /tmp/udstunnel.pid
|
||||
|
||||
# Log level, valid are DEBUG, INFO, WARN, ERROR. Defaults to ERROR
|
||||
loglevel = DEBUG
|
||||
# Log file, No default
|
||||
logfile = /tmp/tunnel.log
|
||||
# Max log size before rotating it. Defaults to 32 MB.
|
||||
# The value is in MB. You can include or not the M string at end.
|
||||
logsize = 20M
|
||||
# Number of backup logs to keep. Defaults to 3
|
||||
lognumber = 3
|
||||
|
||||
# Listen address. Defaults to 0.0.0.0
|
||||
address = 0.0.0.0
|
||||
@ -12,16 +21,17 @@ workers = 2
|
||||
# Listening port
|
||||
port = 7777
|
||||
|
||||
# SSL Related parameters
|
||||
# SSL Related parameters.
|
||||
ssl_certificate = tests/testing.pem
|
||||
ssl_certificate_key = tests/testing.key
|
||||
# ssl_ciphers and ssl_dhparam are optional.
|
||||
ssl_ciphers = ECDHE-RSA-AES256-GCM-SHA512:DHE-RSA-AES256-GCM-SHA512:ECDHE-RSA-AES256-GCM-SHA384:DHE-RSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-SHA384
|
||||
ssl_dhparam = /etc/certs/dhparam.pem
|
||||
|
||||
# UDS server location
|
||||
# UDS server location. https NEEDS valid certificate
|
||||
uds_server = http://172.27.0.1:8000
|
||||
|
||||
# Secret to get access to admin commands
|
||||
# Secret to get access to admin commands. No default for this.
|
||||
# Admin commands and only allowed from localhost
|
||||
# So, in order to allow this commands, ensure listen address allows connections from localhost
|
||||
secret = MySecret
|
||||
|
@ -29,10 +29,11 @@
|
||||
'''
|
||||
@author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
'''
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import multiprocessing
|
||||
import threading
|
||||
import signal
|
||||
import socket
|
||||
import logging
|
||||
import typing
|
||||
@ -54,28 +55,44 @@ BACKLOG = 100
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
do_stop = False
|
||||
|
||||
|
||||
def stop_signal(signum, frame):
|
||||
global do_stop
|
||||
do_stop = True
|
||||
logger.debug('SIGNAL %s, frame: %s', signum, frame)
|
||||
|
||||
|
||||
def setup_log(cfg: config.ConfigurationType) -> None:
|
||||
# Setup basic logging
|
||||
log = logging.getLogger()
|
||||
log.setLevel(logging.DEBUG)
|
||||
handler = logging.StreamHandler(sys.stdout)
|
||||
handler.setLevel(logging.DEBUG)
|
||||
formatter = logging.Formatter(
|
||||
'%(levelname)s - %(message)s'
|
||||
) # Basic log format, nice for syslog
|
||||
handler.setFormatter(formatter)
|
||||
log.addHandler(handler)
|
||||
from logging.handlers import RotatingFileHandler
|
||||
|
||||
# Update logging if needed
|
||||
if cfg.log_file:
|
||||
fileh = logging.FileHandler(cfg.log_file, 'a')
|
||||
fileh = RotatingFileHandler(
|
||||
filename=cfg.log_file,
|
||||
mode='a',
|
||||
maxBytes=cfg.log_size,
|
||||
backupCount=cfg.log_number,
|
||||
)
|
||||
formatter = logging.Formatter(consts.LOGFORMAT)
|
||||
fileh.setFormatter(formatter)
|
||||
log = logging.getLogger()
|
||||
for hdlr in log.handlers[:]:
|
||||
log.removeHandler(hdlr)
|
||||
log.setLevel(cfg.log_level)
|
||||
# for hdlr in log.handlers[:]:
|
||||
# log.removeHandler(hdlr)
|
||||
log.addHandler(fileh)
|
||||
else:
|
||||
# Setup basic logging
|
||||
log = logging.getLogger()
|
||||
log.setLevel(cfg.log_level)
|
||||
handler = logging.StreamHandler(sys.stdout)
|
||||
handler.setLevel(cfg.log_level)
|
||||
formatter = logging.Formatter(
|
||||
'%(levelname)s - %(message)s'
|
||||
) # Basic log format, nice for syslog
|
||||
handler.setFormatter(formatter)
|
||||
log.addHandler(handler)
|
||||
|
||||
|
||||
async def tunnel_proc_async(
|
||||
@ -109,31 +126,46 @@ async def tunnel_proc_async(
|
||||
context.load_dh_params(cfg.ssl_dhparam)
|
||||
|
||||
while True:
|
||||
address = ('', '')
|
||||
try:
|
||||
sock, address = await curio.run_in_thread(get_socket, pipe)
|
||||
if not sock:
|
||||
break
|
||||
logger.debug(
|
||||
f'{multiprocessing.current_process().pid!r}: Got new connection from {address!r}'
|
||||
f'CONNECTION from {address!r} (pid: {os.getpid()})'
|
||||
)
|
||||
sock = await context.wrap_socket(
|
||||
curio.io.Socket(sock), server_side=True
|
||||
)
|
||||
sock = await context.wrap_socket(curio.io.Socket(sock), server_side=True)
|
||||
await group.spawn(tunneler, sock, address)
|
||||
del sock
|
||||
except Exception as e:
|
||||
logger.error('SETING UP CONNECTION: %s', e)
|
||||
except Exception:
|
||||
logger.error('NEGOTIATION ERROR from %s', address[0])
|
||||
|
||||
async with curio.TaskGroup() as tg:
|
||||
await tg.spawn(run_server, pipe, cfg, tg)
|
||||
# Reap all of the children tasks as they complete
|
||||
async for task in tg:
|
||||
logger.debug(f'Deleting {task!r}')
|
||||
logger.debug(f'REMOVING async task {task!r}')
|
||||
task.joined = True
|
||||
del task
|
||||
|
||||
|
||||
def tunnel_main():
|
||||
cfg = config.read()
|
||||
setup_log(cfg)
|
||||
|
||||
# Create pid file
|
||||
try:
|
||||
setup_log(cfg)
|
||||
with open(cfg.pidfile, mode='w') as f:
|
||||
f.write(str(os.getpid()))
|
||||
except Exception as e:
|
||||
sys.stderr.write(f'Tunnel startup error: {e}\n')
|
||||
return
|
||||
|
||||
# Setup signal handlers
|
||||
signal.signal(signal.SIGINT, stop_signal)
|
||||
signal.signal(signal.SIGTERM, stop_signal)
|
||||
|
||||
# Creates as many processes and pipes as required
|
||||
child: typing.List[
|
||||
@ -149,6 +181,7 @@ def tunnel_main():
|
||||
args=(tunnel_proc_async, child_conn, cfg, stats_collector.ns),
|
||||
)
|
||||
task.start()
|
||||
logger.debug('ADD CHILD PID: %s', task.pid)
|
||||
child.append((own_conn, task, psutil.Process(task.pid)))
|
||||
|
||||
def best_child() -> 'Connection':
|
||||
@ -168,24 +201,39 @@ def tunnel_main():
|
||||
try:
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, True)
|
||||
except (AttributeError, OSError) as e:
|
||||
logger.warning('socket.REUSEPORT not available', exc_info=True)
|
||||
logger.warning('socket.REUSEPORT not available')
|
||||
|
||||
sock.settimeout(3.0) # So we can check for stop from time to time
|
||||
sock.bind((cfg.listen_address, cfg.listen_port))
|
||||
sock.listen(BACKLOG)
|
||||
while True:
|
||||
client, addr = sock.accept()
|
||||
# Select BEST process for sending this new connection
|
||||
best_child().send(message.Message(message.Command.TUNNEL, (client, addr)))
|
||||
del client # Ensure socket is controlled on child process
|
||||
except Exception:
|
||||
logger.exception('Mar')
|
||||
while not do_stop:
|
||||
try:
|
||||
client, addr = sock.accept()
|
||||
# Select BEST process for sending this new connection
|
||||
best_child().send(
|
||||
message.Message(message.Command.TUNNEL, (client, addr))
|
||||
)
|
||||
del client # Ensure socket is controlled on child process
|
||||
except socket.timeout:
|
||||
pass # Continue and retry
|
||||
except Exception as e:
|
||||
logger.error('LOOP: %s', e)
|
||||
except Exception as e:
|
||||
logger.error('MAIN: %s', e)
|
||||
pass
|
||||
|
||||
logger.info('Exiting tunnel server')
|
||||
|
||||
if sock:
|
||||
sock.close()
|
||||
|
||||
# Try to stop running childs
|
||||
for i in child:
|
||||
try:
|
||||
i[2].kill()
|
||||
except Exception as e:
|
||||
logger.info('KILLING child %s: %s', i[2], e)
|
||||
|
||||
logger.info('FINISHED')
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
Loading…
x
Reference in New Issue
Block a user