From 971e5984d9569c583250011ff579f9188c265570 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adolfo=20G=C3=B3mez=20Garc=C3=ADa?= Date: Wed, 13 Jan 2021 10:04:26 +0100 Subject: [PATCH] Advancing on new tunneler --- tunnel-server/src/forwarder/uds_forwarder.py | 5 +- tunnel-server/src/uds_tunnel/proxy.py | 30 +-- tunnel-server/src/uds_tunnel/stats.py | 204 +++++-------------- tunnel-server/src/udstunnel.py | 48 +++-- 4 files changed, 102 insertions(+), 185 deletions(-) diff --git a/tunnel-server/src/forwarder/uds_forwarder.py b/tunnel-server/src/forwarder/uds_forwarder.py index e2f4ea4b..27ec2e48 100644 --- a/tunnel-server/src/forwarder/uds_forwarder.py +++ b/tunnel-server/src/forwarder/uds_forwarder.py @@ -143,6 +143,7 @@ class Handler(socketserver.BaseRequestHandler): if not data: break self.request.sendall(data) + logger.debug('Finished process') except Exception as e: pass @@ -173,8 +174,8 @@ def forward( if __name__ == "__main__": fs1 = forward(('fake.udsenterprise.com', 7777), '0'*64, local_port=49998) print(f'Listening on {fs1.server_address}') - #fs2 = forward(('fake.udsenterprise.com', 7777), '1'*64, local_port=49999) - #print(f'Listening on {fs2.server_address}') + fs2 = forward(('fake.udsenterprise.com', 7777), '1'*64, local_port=49999) + print(f'Listening on {fs2.server_address}') # time.sleep(30) # fs.stop() diff --git a/tunnel-server/src/uds_tunnel/proxy.py b/tunnel-server/src/uds_tunnel/proxy.py index 244d0731..b08ed592 100644 --- a/tunnel-server/src/uds_tunnel/proxy.py +++ b/tunnel-server/src/uds_tunnel/proxy.py @@ -38,15 +38,18 @@ from . import config from . import stats from . import consts +if typing.TYPE_CHECKING: + from multiprocessing.managers import Namespace + logger = logging.getLogger(__name__) class Proxy: cfg: config.ConfigurationType - stat: stats.Stats + ns: 'Namespace' - def __init__(self, cfg: config.ConfigurationType) -> None: + def __init__(self, cfg: config.ConfigurationType, ns: 'Namespace') -> None: self.cfg = cfg - self.stat = stats.Stats() + self.ns = ns @staticmethod def getFromUds(cfg: config.ConfigurationType, ticket: bytes) -> typing.MutableMapping[str, typing.Any]: @@ -77,6 +80,10 @@ class Proxy: await destination.sendall(data) counter.add(len(data)) + # Method responsible of proxying requests + async def __call__(self, source, address: typing.Tuple[str, int]) -> None: + await self.proxy(source, address) + async def stats(self, full: bool, source, address: typing.Tuple[str, int]) -> None: # Check valid source ip if address[0] not in self.cfg.allow: @@ -93,17 +100,12 @@ class Proxy: logger.info('STATS TO %s', address) - if full: - data = self.stat.full_as_csv() - else: - data = self.stat.simple_as_csv() + data = stats.GlobalStats.get_stats(self.ns) - async for v in data: + for v in data: + logger.debug('SENDING %s', v) await source.sendall(v.encode() + b'\n') - # Method responsible of proxying requests - async def __call__(self, source, address: typing.Tuple[str, int]) -> None: - await self.proxy(source, address) async def proxy(self, source, address: typing.Tuple[str, int]) -> None: logger.info('OPEN FROM %s', address) @@ -119,7 +121,6 @@ class Proxy: logger.exception('HANDSHAKE') logger.error('HANDSHAKE from %s', address) await source.sendall(b'HANDSHAKE_ERROR') - # Closes connection now return @@ -166,7 +167,7 @@ class Proxy: await source.sendall(b'OK') # Initialize own stats counter - counter = await self.stat.new() + counter = stats.Stats(self.ns) # Open remote server connection try: @@ -184,8 +185,7 @@ class Proxy: logger.error('REMOTE from %s: %s', address, e) finally: - await counter.close() + counter.close() logger.info('CLOSED FROM %s', address) - logger.info('STATS: %s', counter.as_csv()) diff --git a/tunnel-server/src/uds_tunnel/stats.py b/tunnel-server/src/uds_tunnel/stats.py index 25bcf8fb..4b3d22b0 100644 --- a/tunnel-server/src/uds_tunnel/stats.py +++ b/tunnel-server/src/uds_tunnel/stats.py @@ -28,30 +28,30 @@ ''' @author: Adolfo Gómez, dkmaster at dkmon dot com ''' +import multiprocessing import time +import logging +import typing import io import ssl import logging import typing import curio -import blist from . import config from . import consts +if typing.TYPE_CHECKING: + from multiprocessing.managers import Namespace, SyncManager + +INTERVAL = 2 # Interval in seconds between stats update + logger = logging.getLogger(__name__) -# Locker for id assigner -assignLock = curio.Lock() - -# Tuple index for several stats -SENT, RECV = 0, 1 - -# Subclasses for += operation to work class StatsSingleCounter: - def __init__(self, parent: 'StatsConnection', for_receiving=True) -> None: + def __init__(self, parent: 'Stats', for_receiving=True) -> None: if for_receiving: self.adder = parent.add_recv else: @@ -62,56 +62,34 @@ class StatsSingleCounter: return self -class StatsConnection: - id: int - recv: int +class Stats: + ns: 'Namespace' sent: int - start_time: int - parent: 'Stats' + recv: int + last: float - # Bandwidth stats (SENT, RECV) - last: typing.List[int] - last_time: typing.List[float] + def __init__(self, ns: 'Namespace'): + self.ns = ns + self.ns.current += 1 + self.ns.total += 1 + self.sent = 0 + self.recv = 0 + self.last = time.monotonic() - bandwidth: typing.List[int] - max_bandwidth: typing.List[int] - - def __init__(self, parent: 'Stats', id: int) -> None: - self.id = id - self.recv = self.sent = 0 - - now = time.time() - self.start_time = int(now) - self.parent = parent - - self.last = [0, 0] - self.last_time = [now, now] - self.bandwidth = [0, 0] - self.max_bandwidth = [0, 0] - - def update_bandwidth(self, kind: int, counter: int): - now = time.time() - elapsed = now - self.last_time[kind] - # Update only when enouth data - if elapsed < consts.BANDWIDTH_TIME: - return - total = counter - self.last[kind] - self.bandwidth[kind] = int(float(total) / elapsed) - self.last[kind] = counter - self.last_time[kind] = now - - if self.bandwidth[kind] > self.max_bandwidth[kind]: - self.max_bandwidth[kind] = self.bandwidth[kind] + def update(self, force: bool = False): + now = time.monotonic() + if force or now - self.last > INTERVAL: + self.last = now + self.ns.recv = self.recv + self.ns.sent = self.sent def add_recv(self, size: int) -> None: self.recv += size - self.update_bandwidth(RECV, counter=self.recv) - self.parent.add_recv(size) + self.update() def add_sent(self, size: int) -> None: self.sent += size - self.update_bandwidth(SENT, counter=self.sent) - self.parent.add_sent(size) + self.update() def as_sent_counter(self) -> 'StatsSingleCounter': return StatsSingleCounter(self, False) @@ -119,114 +97,34 @@ class StatsConnection: def as_recv_counter(self) -> 'StatsSingleCounter': return StatsSingleCounter(self, True) - async def close(self) -> None: - if self.id: - logger.debug(f'STAT {self.id} closed') - await self.parent.remove(self.id) - self.id = 0 + def close(self): + self.update(True) + self.ns.current -= 1 - def as_csv(self, separator: typing.Optional[str] = None) -> str: - separator = separator or ';' - # With connections of less than a second, consider them as a second - elapsed = (int(time.time()) - self.start_time) +# Stats collector thread +class GlobalStats: + manager: 'SyncManager' + ns: 'Namespace' + counter: int - return separator.join( - str(i) - for i in ( - self.id, - self.start_time, - elapsed, - self.sent, - self.bandwidth[SENT], - self.max_bandwidth[SENT], - self.recv, - self.bandwidth[RECV], - self.max_bandwidth[RECV], - ) - ) + def __init__(self): + super().__init__() + self.manager = multiprocessing.Manager() + self.ns = self.manager.Namespace() - def __str__(self) -> str: - return f'{self.id} t:{int(time.time())-self.start_time}, r:{self.recv}, s:{self.sent}>' + # Counters + self.ns.current = 0 + self.ns.total = 0 + self.ns.sent = 0 + self.ns.recv = 0 + self.counter = 0 - # For sorted array - def __lt__(self, other) -> bool: - if isinstance(other, int): - return self.id < other - - if not isinstance(other, StatsConnection): - raise NotImplemented - - return self.id < other.id - - def __eq__(self, other) -> bool: - if isinstance(other, int): - return self.id == other - - if not isinstance(other, StatsConnection): - raise NotImplemented - - return self.id == other.id - - -class Stats: - counter_id: int - - total_sent: int - total_received: int - current_connections: blist.sortedlist - - def __init__(self) -> None: - # First connection will be 1 - self.counter_id = 0 - self.total_sent = self.total_received = 0 - self.current_connections = blist.sortedlist() - - async def new(self) -> StatsConnection: - """Initializes a connection stats counter and returns it id - - Returns: - str: connection id - """ - async with assignLock: - self.counter_id += 1 - connection = StatsConnection(self, self.counter_id) - self.current_connections.add(connection) - return connection - - def add_sent(self, size: int) -> None: - self.total_sent += size - - def add_recv(self, size: int) -> None: - self.total_received += size - - async def remove(self, connection_id: int) -> None: - async with assignLock: - try: - self.current_connections.remove(connection_id) - except Exception: - logger.debug( - 'Tried to remove %s from connections but was not present', - connection_id, - ) - # Does not exists, ignore it - pass - - async def simple_as_csv(self, separator: typing.Optional[str] = None) -> typing.AsyncIterable[str]: - separator = separator or ';' - yield separator.join( - str(i) - for i in ( - self.counter_id, - self.total_sent, - self.total_received, - len(self.current_connections), - ) - ) - - async def full_as_csv(self, separator: typing.Optional[str] = None) -> typing.AsyncIterable[str]: - for i in self.current_connections: - yield i.as_csv(separator) + def info(self) -> typing.Iterable[str]: + return GlobalStats.get_stats(self.ns) + @staticmethod + def get_stats(ns: 'Namespace') -> typing.Iterable[str]: + yield ';'.join([str(ns.current), str(ns.total), str(ns.sent), str(ns.recv)]) # Stats processor, invoked from command line async def getServerStats(detailed: bool = False) -> None: diff --git a/tunnel-server/src/udstunnel.py b/tunnel-server/src/udstunnel.py index b2d52e23..b1e9f3a2 100755 --- a/tunnel-server/src/udstunnel.py +++ b/tunnel-server/src/udstunnel.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- # -# Copyright (c) 2020 Virtual Cable S.L.U. +# Copyright (c) 2021 Virtual Cable S.L.U. # All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, @@ -32,6 +32,7 @@ import sys import argparse import multiprocessing +import threading import socket import logging import typing @@ -47,6 +48,7 @@ from uds_tunnel import stats if typing.TYPE_CHECKING: from multiprocessing.connection import Connection + from multiprocessing.managers import Namespace BACKLOG = 100 @@ -76,7 +78,9 @@ def setup_log(cfg: config.ConfigurationType) -> None: log.addHandler(fileh) -async def tunnel_proc_async(pipe: 'Connection', cfg: config.ConfigurationType) -> None: +async def tunnel_proc_async( + pipe: 'Connection', cfg: config.ConfigurationType, ns: 'Namespace' +) -> None: def get_socket(pipe: 'Connection') -> typing.Tuple[socket.SocketType, typing.Any]: try: while True: @@ -92,7 +96,7 @@ async def tunnel_proc_async(pipe: 'Connection', cfg: config.ConfigurationType) - pipe: 'Connection', cfg: config.ConfigurationType, group: curio.TaskGroup ) -> None: # Instantiate a proxy redirector for this process (we only need one per process!!) - tunneler = proxy.Proxy(cfg) + tunneler = proxy.Proxy(cfg, ns) # Generate SSL context context = curio.ssl.SSLContext(curio.ssl.PROTOCOL_TLS_SERVER) @@ -105,15 +109,18 @@ async def tunnel_proc_async(pipe: 'Connection', cfg: config.ConfigurationType) - context.load_dh_params(cfg.ssl_dhparam) while True: - sock, address = await curio.run_in_thread(get_socket, pipe) - if not sock: - break - logger.debug( - f'{multiprocessing.current_process().pid!r}: Got new connection from {address!r}' - ) - sock = await context.wrap_socket(curio.io.Socket(sock), server_side=True) - await group.spawn(tunneler, sock, address) - del sock + try: + sock, address = await curio.run_in_thread(get_socket, pipe) + if not sock: + break + logger.debug( + f'{multiprocessing.current_process().pid!r}: Got new connection from {address!r}' + ) + sock = await context.wrap_socket(curio.io.Socket(sock), server_side=True) + await group.spawn(tunneler, sock, address) + del sock + except Exception as e: + logger.error('SETING UP CONNECTION: %s', e) async with curio.TaskGroup() as tg: await tg.spawn(run_server, pipe, cfg, tg) @@ -133,10 +140,13 @@ def tunnel_main(): typing.Tuple['Connection', multiprocessing.Process, psutil.Process] ] = [] + stats_collector = stats.GlobalStats() + for i in range(cfg.workers): own_conn, child_conn = multiprocessing.Pipe() task = multiprocessing.Process( - target=curio.run, args=(tunnel_proc_async, child_conn, cfg) + target=curio.run, + args=(tunnel_proc_async, child_conn, cfg, stats_collector.ns), ) task.start() child.append((own_conn, task, psutil.Process(task.pid))) @@ -166,10 +176,13 @@ def tunnel_main(): client, addr = sock.accept() # Select BEST process for sending this new connection best_child().send(message.Message(message.Command.TUNNEL, (client, addr))) - + del client # Ensure socket is controlled on child process except Exception: + logger.exception('Mar') pass + logger.info('Exiting tunnel server') + if sock: sock.close() @@ -196,7 +209,12 @@ def main() -> None: if args.tunnel: tunnel_main() - parser.print_help() + elif args.detailed_stats: + curio.run(stats.getServerStats, True) + elif args.stats: + curio.run(stats.getServerStats, False) + else: + parser.print_help() if __name__ == "__main__":