From 7d8ae689b56e0652a3c44b8626dee281c1b38204 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adolfo=20G=C3=B3mez=20Garc=C3=ADa?= Date: Wed, 26 Jan 2022 12:18:41 +0100 Subject: [PATCH] Moving from curio to asyncio. --- tunnel-server/requirements.txt | 1 - tunnel-server/src/uds_tunnel/message.py | 45 ---------- tunnel-server/src/uds_tunnel/processes.py | 12 +-- tunnel-server/src/uds_tunnel/proxy.py | 55 ++++++------ tunnel-server/src/uds_tunnel/stats.py | 28 ++++-- tunnel-server/src/udstunnel.py | 103 ++++++++++++---------- 6 files changed, 109 insertions(+), 135 deletions(-) delete mode 100644 tunnel-server/src/uds_tunnel/message.py diff --git a/tunnel-server/requirements.txt b/tunnel-server/requirements.txt index d2268fa6..e20e7482 100644 --- a/tunnel-server/requirements.txt +++ b/tunnel-server/requirements.txt @@ -1,2 +1 @@ -curio>=1.4 psutil>=5.7.3 diff --git a/tunnel-server/src/uds_tunnel/message.py b/tunnel-server/src/uds_tunnel/message.py deleted file mode 100644 index 4a5fabba..00000000 --- a/tunnel-server/src/uds_tunnel/message.py +++ /dev/null @@ -1,45 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Copyright (c) 2021 Virtual Cable S.L.U. -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without modification, -# are permitted provided that the following conditions are met: -# -# * Redistributions of source code must retain the above copyright notice, -# this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# * Neither the name of Virtual Cable S.L. nor the names of its contributors -# may be used to endorse or promote products derived from this software -# without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -''' -@author: Adolfo Gómez, dkmaster at dkmon dot com -''' -import enum -import socket -import typing - -class Command(enum.IntEnum): - TUNNEL = 0 - STATS = 1 - -class Message: - command: Command - connection: typing.Optional[typing.Tuple[socket.socket, typing.Any]] - - def __init__(self, command: Command, connection: typing.Optional[typing.Tuple[socket.socket, typing.Any]]): - self.command = command - self.connection = connection diff --git a/tunnel-server/src/uds_tunnel/processes.py b/tunnel-server/src/uds_tunnel/processes.py index d2db355e..e62267fb 100644 --- a/tunnel-server/src/uds_tunnel/processes.py +++ b/tunnel-server/src/uds_tunnel/processes.py @@ -1,8 +1,8 @@ import multiprocessing +import asyncio import logging import typing -import curio import psutil from . import config @@ -13,6 +13,8 @@ if typing.TYPE_CHECKING: logger = logging.getLogger(__name__) +ProcessType = typing.Callable[['Connection', config.ConfigurationType, 'Namespace'], typing.Coroutine[typing.Any, None, None]] + class Processes: """ This class is used to store the processes that are used by the tunnel. @@ -21,11 +23,11 @@ class Processes: children: typing.List[ typing.Tuple['Connection', multiprocessing.Process, psutil.Process] ] - process: typing.Callable + process: ProcessType cfg: config.ConfigurationType ns: 'Namespace' - def __init__(self, process: typing.Callable, cfg: config.ConfigurationType, ns: 'Namespace') -> None: + def __init__(self, process: ProcessType, cfg: config.ConfigurationType, ns: 'Namespace') -> None: self.children = [] self.process = process # type: ignore self.cfg = cfg @@ -37,8 +39,8 @@ class Processes: def add_child_pid(self): own_conn, child_conn = multiprocessing.Pipe() task = multiprocessing.Process( - target=curio.run, - args=(self.process, child_conn, self.cfg, self.ns) + target=asyncio.run, + args=(self.process(child_conn, self.cfg, self.ns),) ) task.start() logger.debug('ADD CHILD PID: %s', task.pid) diff --git a/tunnel-server/src/uds_tunnel/proxy.py b/tunnel-server/src/uds_tunnel/proxy.py index 2fefa72c..63c06af4 100644 --- a/tunnel-server/src/uds_tunnel/proxy.py +++ b/tunnel-server/src/uds_tunnel/proxy.py @@ -28,10 +28,11 @@ ''' @author: Adolfo Gómez, dkmaster at dkmon dot com ''' +import asyncio +import socket import logging import typing -import curio import requests from . import config @@ -40,7 +41,7 @@ from . import consts if typing.TYPE_CHECKING: from multiprocessing.managers import Namespace - import curio.io + import ssl logger = logging.getLogger(__name__) @@ -109,29 +110,29 @@ class Proxy: cfg, ticket, 'stop', {'sent': str(counter.sent), 'recv': str(counter.recv)} ) # Ignore results - @staticmethod - async def doProxy( - source: 'curio.io.Socket', - destination: 'curio.io.Socket', - counter: stats.StatsSingleCounter, - ) -> None: - try: - while True: - data = await source.recv(consts.BUFFER_SIZE) - if not data: - break - await destination.sendall(data) - counter.add(len(data)) - except Exception: - # Connection broken, same result as closed for us - # We must notice that i'ts easy that when closing one part of the tunnel, - # the other can break (due to some internal data), that's why even log is removed - # logger.info('CONNECTION LOST FROM %s to %s', source.getsockname(), destination.getpeername()) - pass + # @staticmethod + # async def doProxy( + # source: 'curio.io.Socket', + # destination: 'curio.io.Socket', + # counter: stats.StatsSingleCounter, + # ) -> None: + # try: + # while True: + # data = await source.recv(consts.BUFFER_SIZE) + # if not data: + # break + # await destination.sendall(data) + # counter.add(len(data)) + # except Exception: + # # Connection broken, same result as closed for us + # # We must notice that i'ts easy that when closing one part of the tunnel, + # # the other can break (due to some internal data), that's why even log is removed + # # logger.info('CONNECTION LOST FROM %s to %s', source.getsockname(), destination.getpeername()) + # pass # Method responsible of proxying requests - async def __call__(self, source, address: typing.Tuple[str, int]) -> None: - await self.proxy(source, address) + async def __call__(self, source: socket.socket, address: typing.Tuple[str, int], context: 'ssl.SSLContext') -> None: + await self.proxy(source, address, context) async def stats(self, full: bool, source, address: typing.Tuple[str, int]) -> None: # Check valid source ip @@ -153,17 +154,19 @@ class Proxy: logger.debug('SENDING %s', v) await source.sendall(v.encode() + b'\n') - async def proxy(self, source, address: typing.Tuple[str, int]) -> None: + async def proxy(self, source: socket.socket, address: typing.Tuple[str, int], context: 'ssl.SSLContext') -> None: prettySource = address[0] # Get only source IP prettyDest = '' logger.info('CONNECT FROM %s', prettySource) + loop = asyncio.get_event_loop() + # Handshake correct in this point, start SSL connection try: - command: bytes = await source.recv(consts.COMMAND_LENGTH) + command: bytes = await loop.sock_recv(source, consts.COMMAND_LENGTH) if command == consts.COMMAND_TEST: logger.info('COMMAND: TEST') - await source.sendall(b'OK') + await loop.sock_sendall(source, b'OK') logger.info('TERMINATED %s', prettySource) return diff --git a/tunnel-server/src/uds_tunnel/stats.py b/tunnel-server/src/uds_tunnel/stats.py index a0530ed2..b35e42c2 100644 --- a/tunnel-server/src/uds_tunnel/stats.py +++ b/tunnel-server/src/uds_tunnel/stats.py @@ -29,15 +29,16 @@ @author: Adolfo Gómez, dkmaster at dkmon dot com ''' import multiprocessing +import socket import time import logging import typing import io +import asyncio import ssl import logging import typing -import curio from . import config from . import consts @@ -146,15 +147,24 @@ async def getServerStats(detailed: bool = False) -> None: try: host = cfg.listen_address if cfg.listen_address != '0.0.0.0' else 'localhost' - sock = await curio.open_connection( - host, cfg.listen_port, ssl=context, server_hostname='localhost' - ) - tmpdata = io.BytesIO() - cmd = consts.COMMAND_STAT if detailed else consts.COMMAND_INFO - async with sock: - await sock.sendall(consts.HANDSHAKE_V1 + cmd + cfg.secret.encode()) + reader: asyncio.StreamReader + writer: asyncio.StreamWriter + + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.connect((host, cfg.listen_port)) + # Send HANDSHAKE + sock.sendall(consts.HANDSHAKE_V1) + # Ugrade connection to TLS + reader, writer = await asyncio.open_connection(sock=sock, ssl=context, server_hostname=host) + + tmpdata = io.BytesIO() + cmd = consts.COMMAND_STAT if detailed else consts.COMMAND_INFO + + writer.write(cmd + cfg.secret.encode()) + await writer.drain() + while True: - chunk = await sock.recv(consts.BUFFER_SIZE) + chunk = await reader.read(consts.BUFFER_SIZE) if not chunk: break tmpdata.write(chunk) diff --git a/tunnel-server/src/udstunnel.py b/tunnel-server/src/udstunnel.py index 61496605..54060a3d 100755 --- a/tunnel-server/src/udstunnel.py +++ b/tunnel-server/src/udstunnel.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- # -# Copyright (c) 2021 Virtual Cable S.L.U. +# Copyright (c) 2021-2022 Virtual Cable S.L.U. # All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, @@ -32,22 +32,18 @@ import os import pwd import sys +import asyncio import argparse import signal +import ssl import socket import logging import typing -import curio import setproctitle -from uds_tunnel import config -from uds_tunnel import proxy -from uds_tunnel import consts -from uds_tunnel import message -from uds_tunnel import stats -from uds_tunnel import processes +from uds_tunnel import config, proxy, consts, processes, stats if typing.TYPE_CHECKING: from multiprocessing.connection import Connection @@ -100,17 +96,34 @@ def setup_log(cfg: config.ConfigurationType) -> None: async def tunnel_proc_async( pipe: 'Connection', cfg: config.ConfigurationType, ns: 'Namespace' ) -> None: - def get_socket(pipe: 'Connection') -> typing.Tuple[typing.Optional[socket.SocketType], typing.Any]: + + loop = asyncio.get_event_loop() + # Create event for flagging when we have new data + event = asyncio.Event() + loop.add_reader(pipe.fileno(), event.set) + + tasks: typing.List[asyncio.Task] = [] + + async def get_socket() -> typing.Tuple[ + typing.Optional[socket.socket], typing.Tuple[str, int] + ]: try: while True: - msg: message.Message = pipe.recv() - if msg.command == message.Command.TUNNEL and msg.connection: + await event.wait() + # Clear back event, for next data + event.clear() + msg: typing.Optional[ + typing.Tuple[socket.socket, typing.Tuple[str, int]] + ] = pipe.recv() + if msg: # Connection done, check for handshake - source, address = msg.connection + source, address = msg try: # First, ensure handshake (simple handshake) and command - data: bytes = source.recv(len(consts.HANDSHAKE_V1)) + data: bytes = await loop.sock_recv( + source, len(consts.HANDSHAKE_V1) + ) if data != consts.HANDSHAKE_V1: raise Exception() # Invalid handshake @@ -122,21 +135,19 @@ async def tunnel_proc_async( source.close() continue - return msg.connection + return msg # Process other messages, and retry except Exception: logger.exception('Receiving data from parent process') - return None, None + return None, ('', 0) - async def run_server( - pipe: 'Connection', cfg: config.ConfigurationType, group: curio.TaskGroup - ) -> None: + async def run_server() -> None: # Instantiate a proxy redirector for this process (we only need one per process!!) tunneler = proxy.Proxy(cfg, ns) # Generate SSL context - context = curio.ssl.SSLContext(curio.ssl.PROTOCOL_TLS_SERVER) # type: ignore + context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) context.load_cert_chain(cfg.ssl_certificate, cfg.ssl_certificate_key) if cfg.ssl_ciphers: @@ -146,30 +157,27 @@ async def tunnel_proc_async( context.load_dh_params(cfg.ssl_dhparam) while True: - address = ('', '') + address: typing.Tuple[str, int] = ('', 0) try: - sock, address = await curio.run_in_thread(get_socket, pipe) + sock, address = await get_socket() if not sock: - break - logger.debug( - f'CONNECTION from {address!r} (pid: {os.getpid()})' - ) - sock = await context.wrap_socket( - curio.io.Socket(sock), server_side=True # type: ignore - ) - await group.spawn(tunneler, sock, address) - del sock + break # No more sockets, exit + logger.debug(f'CONNECTION from {address!r} (pid: {os.getpid()})') + tasks.append(asyncio.create_task(tunneler(sock, address, context))) except Exception: logger.error('NEGOTIATION ERROR from %s', address[0]) - async with curio.TaskGroup() as tg: - await tg.spawn(run_server, pipe, cfg, tg) - await tg.join() - # Reap all of the children tasks as they complete - # async for task in tg: - # logger.debug(f'REMOVING async task {task!r}') - # task.joined = True - # del task + # create task for server + tasks.append(asyncio.create_task(run_server())) + + while tasks: + tasks_number = len(tasks) + await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED) + # Remove finished tasks from list + del tasks[:tasks_number] + + # Remove reader from event loop + loop.remove_reader(pipe.fileno()) def tunnel_main(): @@ -200,7 +208,9 @@ def tunnel_main(): setup_log(cfg) - logger.info('Starting tunnel server on %s:%s', cfg.listen_address, cfg.listen_port) + logger.info( + 'Starting tunnel server on %s:%s', cfg.listen_address, cfg.listen_port + ) setproctitle.setproctitle(f'UDSTunnel {cfg.listen_address}:{cfg.listen_port}') # Create pid file @@ -213,7 +223,6 @@ def tunnel_main(): logger.error('MAIN: %s', e) return - # Setup signal handlers signal.signal(signal.SIGINT, stop_signal) signal.signal(signal.SIGTERM, stop_signal) @@ -227,9 +236,7 @@ def tunnel_main(): try: client, addr = sock.accept() # Select BEST process for sending this new connection - prcs.best_child().send( - message.Message(message.Command.TUNNEL, (client, addr)) - ) + prcs.best_child().send((client, addr)) del client # Ensure socket is controlled on child process except socket.timeout: pass # Continue and retry @@ -259,9 +266,7 @@ def main() -> None: group.add_argument( '-t', '--tunnel', help='Starts the tunnel server', action='store_true' ) - group.add_argument( - '-r', '--rdp', help='RDP Tunnel for traffic accounting' - ) + group.add_argument('-r', '--rdp', help='RDP Tunnel for traffic accounting') group.add_argument( '-s', '--stats', @@ -281,12 +286,12 @@ def main() -> None: elif args.rdp: pass elif args.detailed_stats: - curio.run(stats.getServerStats, True) + asyncio.run(stats.getServerStats(True)) elif args.stats: - curio.run(stats.getServerStats, False) + asyncio.run(stats.getServerStats(False)) else: parser.print_help() if __name__ == "__main__": - main() \ No newline at end of file + main()