From d6a8639b1874fce8727029f203023f613c3b6c38 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adolfo=20G=C3=B3mez=20Garc=C3=ADa?= Date: Thu, 14 Jan 2021 06:01:06 +0100 Subject: [PATCH] new tunnel server ready for testing phase --- tunnel-server/src/uds_tunnel/config.py | 12 +++ tunnel-server/src/uds_tunnel/proxy.py | 55 ++++++++----- tunnel-server/src/udstunnel.cfg | 20 +++-- tunnel-server/src/udstunnel.py | 108 ++++++++++++++++++------- 4 files changed, 140 insertions(+), 55 deletions(-) diff --git a/tunnel-server/src/uds_tunnel/config.py b/tunnel-server/src/uds_tunnel/config.py index 40b8f2cf..b9f37ee3 100644 --- a/tunnel-server/src/uds_tunnel/config.py +++ b/tunnel-server/src/uds_tunnel/config.py @@ -39,8 +39,12 @@ from .consts import CONFIGFILE logger = logging.getLogger(__name__) class ConfigurationType(typing.NamedTuple): + pidfile: str + log_level: str log_file: str + log_size: int + log_number: int listen_address: str listen_port: int @@ -73,10 +77,18 @@ def read() -> ConfigurationType: h.update(uds.get('secret', '').encode()) secret = h.hexdigest() + try: + # log size + logsize: str = uds.get('logsize', '32M') + if logsize[-1] == 'M': + logsize = logsize[:-1] return ConfigurationType( + pidfile=uds.get('pidfile', '/dev/null'), log_level=uds.get('loglevel', 'ERROR'), log_file=uds.get('logfile', ''), + log_size=int(logsize)*1024*1024, + log_number=int(uds.get('lognumber', '3')), listen_address=uds.get('address', '0.0.0.0'), listen_port=int(uds.get('port', '443')), workers=int(uds.get('workers', '0')) or multiprocessing.cpu_count(), diff --git a/tunnel-server/src/uds_tunnel/proxy.py b/tunnel-server/src/uds_tunnel/proxy.py index b08ed592..8c3d9b9e 100644 --- a/tunnel-server/src/uds_tunnel/proxy.py +++ b/tunnel-server/src/uds_tunnel/proxy.py @@ -43,6 +43,7 @@ if typing.TYPE_CHECKING: logger = logging.getLogger(__name__) + class Proxy: cfg: config.ConfigurationType ns: 'Namespace' @@ -52,13 +53,19 @@ class Proxy: self.ns = ns @staticmethod - def getFromUds(cfg: config.ConfigurationType, ticket: bytes) -> typing.MutableMapping[str, typing.Any]: + def getFromUds( + cfg: config.ConfigurationType, ticket: bytes + ) -> typing.MutableMapping[str, typing.Any]: # Sanity checks if len(ticket) != consts.TICKET_LENGTH: raise Exception(f'TICKET INVALID (len={len(ticket)})') for n, i in enumerate(ticket.decode(errors='ignore')): - if (i >= 'a' and i <= 'z') or (i >= '0' and i <= '9') or (i >= 'A' and i <= 'Z'): + if ( + (i >= 'a' and i <= 'z') + or (i >= '0' and i <= '9') + or (i >= 'A' and i <= 'Z') + ): continue # Correctus raise Exception(f'TICKET INVALID (char {i} at pos {n})') @@ -68,7 +75,7 @@ class Proxy: # raise Exception(f'TICKET INVALID (check {r.json})') return { 'host': ['172.27.1.15', '172.27.0.10'][int(ticket[0]) - 0x30], - 'port': '3389' + 'port': '3389', } @staticmethod @@ -98,17 +105,16 @@ class Proxy: await source.sendall(b'FORBIDDEN') return - logger.info('STATS TO %s', address) - data = stats.GlobalStats.get_stats(self.ns) for v in data: logger.debug('SENDING %s', v) await source.sendall(v.encode() + b'\n') - async def proxy(self, source, address: typing.Tuple[str, int]) -> None: - logger.info('OPEN FROM %s', address) + pretty_adress = address[0] # Get only source IP + + logger.info('CONNECT FROM %s', pretty_adress) try: # First, ensure handshake (simple handshake) and command @@ -128,12 +134,16 @@ class Proxy: # Handshake correct, get the command (4 bytes) command: bytes = await source.recv(consts.COMMAND_LENGTH) if command == consts.COMMAND_TEST: + logger.info('COMMAND: TEST') await source.sendall(b'OK') return if command in (consts.COMMAND_STAT, consts.COMMAND_INFO): + logger.info('COMMAND: %s', command.decode()) # This is an stats requests - await self.stats(full=command==consts.COMMAND_STAT, source=source, address=address) + await self.stats( + full=command == consts.COMMAND_STAT, source=source, address=address + ) return if command != consts.COMMAND_OPEN: @@ -147,19 +157,19 @@ class Proxy: try: result = await curio.run_in_thread(Proxy.getFromUds, self.cfg, ticket) except Exception as e: - logger.error('%s', e.args[0] if e.args else e) + logger.error('ERROR %s', e.args[0] if e.args else e) raise - print(f'Result: {result}') - # Invalid result from UDS, not allowed to connect if not result: - raise Exception() + raise Exception('INVALID TICKET') + + logger.info('OPEN TUNNEL FROM %s to %s:%s', pretty_adress, result['host'], result['port']) except Exception: if consts.DEBUG: logger.exception('COMMAND') - logger.error('COMMAND from %s', address) + logger.error('ERROR from %s', address) await source.sendall(b'COMMAND_ERROR') return @@ -171,11 +181,17 @@ class Proxy: # Open remote server connection try: - destination = await curio.open_connection(result['host'], int(result['port'])) + destination = await curio.open_connection( + result['host'], int(result['port']) + ) async with curio.TaskGroup(wait=any) as grp: - await grp.spawn(Proxy.doProxy, source, destination, counter.as_sent_counter()) - await grp.spawn(Proxy.doProxy, destination, source, counter.as_recv_counter()) - logger.debug('Launched proxies') + await grp.spawn( + Proxy.doProxy, source, destination, counter.as_sent_counter() + ) + await grp.spawn( + Proxy.doProxy, destination, source, counter.as_recv_counter() + ) + logger.debug('PROXIES READY') logger.debug('Proxies finalized: %s', grp.exceptions) @@ -185,7 +201,6 @@ class Proxy: logger.error('REMOTE from %s: %s', address, e) finally: - counter.close() + counter.close() # So we ensure stats are correctly updated on ns - - logger.info('CLOSED FROM %s', address) + logger.info('TERMINATED %s', ':'.join(str(i) for i in address)) diff --git a/tunnel-server/src/udstunnel.cfg b/tunnel-server/src/udstunnel.cfg index 76b959a6..c883738d 100644 --- a/tunnel-server/src/udstunnel.cfg +++ b/tunnel-server/src/udstunnel.cfg @@ -1,7 +1,16 @@ -# Sample testing UDS tunnel configuration +# Sample DS tunnel configuration -# Log level, valid are DEBUG, INFO, WARN, ERROR +pidfile = /tmp/udstunnel.pid + +# Log level, valid are DEBUG, INFO, WARN, ERROR. Defaults to ERROR loglevel = DEBUG +# Log file, No default +logfile = /tmp/tunnel.log +# Max log size before rotating it. Defaults to 32 MB. +# The value is in MB. You can include or not the M string at end. +logsize = 20M +# Number of backup logs to keep. Defaults to 3 +lognumber = 3 # Listen address. Defaults to 0.0.0.0 address = 0.0.0.0 @@ -12,16 +21,17 @@ workers = 2 # Listening port port = 7777 -# SSL Related parameters +# SSL Related parameters. ssl_certificate = tests/testing.pem ssl_certificate_key = tests/testing.key +# ssl_ciphers and ssl_dhparam are optional. ssl_ciphers = ECDHE-RSA-AES256-GCM-SHA512:DHE-RSA-AES256-GCM-SHA512:ECDHE-RSA-AES256-GCM-SHA384:DHE-RSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-SHA384 ssl_dhparam = /etc/certs/dhparam.pem -# UDS server location +# UDS server location. https NEEDS valid certificate uds_server = http://172.27.0.1:8000 -# Secret to get access to admin commands +# Secret to get access to admin commands. No default for this. # Admin commands and only allowed from localhost # So, in order to allow this commands, ensure listen address allows connections from localhost secret = MySecret diff --git a/tunnel-server/src/udstunnel.py b/tunnel-server/src/udstunnel.py index b1e9f3a2..226c784f 100755 --- a/tunnel-server/src/udstunnel.py +++ b/tunnel-server/src/udstunnel.py @@ -29,10 +29,11 @@ ''' @author: Adolfo Gómez, dkmaster at dkmon dot com ''' +import os import sys import argparse import multiprocessing -import threading +import signal import socket import logging import typing @@ -54,28 +55,44 @@ BACKLOG = 100 logger = logging.getLogger(__name__) +do_stop = False + + +def stop_signal(signum, frame): + global do_stop + do_stop = True + logger.debug('SIGNAL %s, frame: %s', signum, frame) + def setup_log(cfg: config.ConfigurationType) -> None: - # Setup basic logging - log = logging.getLogger() - log.setLevel(logging.DEBUG) - handler = logging.StreamHandler(sys.stdout) - handler.setLevel(logging.DEBUG) - formatter = logging.Formatter( - '%(levelname)s - %(message)s' - ) # Basic log format, nice for syslog - handler.setFormatter(formatter) - log.addHandler(handler) + from logging.handlers import RotatingFileHandler # Update logging if needed if cfg.log_file: - fileh = logging.FileHandler(cfg.log_file, 'a') + fileh = RotatingFileHandler( + filename=cfg.log_file, + mode='a', + maxBytes=cfg.log_size, + backupCount=cfg.log_number, + ) formatter = logging.Formatter(consts.LOGFORMAT) fileh.setFormatter(formatter) log = logging.getLogger() - for hdlr in log.handlers[:]: - log.removeHandler(hdlr) + log.setLevel(cfg.log_level) + # for hdlr in log.handlers[:]: + # log.removeHandler(hdlr) log.addHandler(fileh) + else: + # Setup basic logging + log = logging.getLogger() + log.setLevel(cfg.log_level) + handler = logging.StreamHandler(sys.stdout) + handler.setLevel(cfg.log_level) + formatter = logging.Formatter( + '%(levelname)s - %(message)s' + ) # Basic log format, nice for syslog + handler.setFormatter(formatter) + log.addHandler(handler) async def tunnel_proc_async( @@ -109,31 +126,46 @@ async def tunnel_proc_async( context.load_dh_params(cfg.ssl_dhparam) while True: + address = ('', '') try: sock, address = await curio.run_in_thread(get_socket, pipe) if not sock: break logger.debug( - f'{multiprocessing.current_process().pid!r}: Got new connection from {address!r}' + f'CONNECTION from {address!r} (pid: {os.getpid()})' + ) + sock = await context.wrap_socket( + curio.io.Socket(sock), server_side=True ) - sock = await context.wrap_socket(curio.io.Socket(sock), server_side=True) await group.spawn(tunneler, sock, address) del sock - except Exception as e: - logger.error('SETING UP CONNECTION: %s', e) + except Exception: + logger.error('NEGOTIATION ERROR from %s', address[0]) async with curio.TaskGroup() as tg: await tg.spawn(run_server, pipe, cfg, tg) # Reap all of the children tasks as they complete async for task in tg: - logger.debug(f'Deleting {task!r}') + logger.debug(f'REMOVING async task {task!r}') task.joined = True del task def tunnel_main(): cfg = config.read() - setup_log(cfg) + + # Create pid file + try: + setup_log(cfg) + with open(cfg.pidfile, mode='w') as f: + f.write(str(os.getpid())) + except Exception as e: + sys.stderr.write(f'Tunnel startup error: {e}\n') + return + + # Setup signal handlers + signal.signal(signal.SIGINT, stop_signal) + signal.signal(signal.SIGTERM, stop_signal) # Creates as many processes and pipes as required child: typing.List[ @@ -149,6 +181,7 @@ def tunnel_main(): args=(tunnel_proc_async, child_conn, cfg, stats_collector.ns), ) task.start() + logger.debug('ADD CHILD PID: %s', task.pid) child.append((own_conn, task, psutil.Process(task.pid))) def best_child() -> 'Connection': @@ -168,24 +201,39 @@ def tunnel_main(): try: sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, True) except (AttributeError, OSError) as e: - logger.warning('socket.REUSEPORT not available', exc_info=True) + logger.warning('socket.REUSEPORT not available') + sock.settimeout(3.0) # So we can check for stop from time to time sock.bind((cfg.listen_address, cfg.listen_port)) sock.listen(BACKLOG) - while True: - client, addr = sock.accept() - # Select BEST process for sending this new connection - best_child().send(message.Message(message.Command.TUNNEL, (client, addr))) - del client # Ensure socket is controlled on child process - except Exception: - logger.exception('Mar') + while not do_stop: + try: + client, addr = sock.accept() + # Select BEST process for sending this new connection + best_child().send( + message.Message(message.Command.TUNNEL, (client, addr)) + ) + del client # Ensure socket is controlled on child process + except socket.timeout: + pass # Continue and retry + except Exception as e: + logger.error('LOOP: %s', e) + except Exception as e: + logger.error('MAIN: %s', e) pass - logger.info('Exiting tunnel server') - if sock: sock.close() + # Try to stop running childs + for i in child: + try: + i[2].kill() + except Exception as e: + logger.info('KILLING child %s: %s', i[2], e) + + logger.info('FINISHED') + def main() -> None: parser = argparse.ArgumentParser()