Moving from curio to asyncio.

This commit is contained in:
Adolfo Gómez García 2022-01-26 12:18:41 +01:00
parent 143b9b675b
commit 7d8ae689b5
6 changed files with 109 additions and 135 deletions

View File

@ -1,2 +1 @@
curio>=1.4
psutil>=5.7.3 psutil>=5.7.3

View File

@ -1,45 +0,0 @@
# -*- coding: utf-8 -*-
#
# Copyright (c) 2021 Virtual Cable S.L.U.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of Virtual Cable S.L. nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
'''
@author: Adolfo Gómez, dkmaster at dkmon dot com
'''
import enum
import socket
import typing
class Command(enum.IntEnum):
TUNNEL = 0
STATS = 1
class Message:
command: Command
connection: typing.Optional[typing.Tuple[socket.socket, typing.Any]]
def __init__(self, command: Command, connection: typing.Optional[typing.Tuple[socket.socket, typing.Any]]):
self.command = command
self.connection = connection

View File

@ -1,8 +1,8 @@
import multiprocessing import multiprocessing
import asyncio
import logging import logging
import typing import typing
import curio
import psutil import psutil
from . import config from . import config
@ -13,6 +13,8 @@ if typing.TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
ProcessType = typing.Callable[['Connection', config.ConfigurationType, 'Namespace'], typing.Coroutine[typing.Any, None, None]]
class Processes: class Processes:
""" """
This class is used to store the processes that are used by the tunnel. This class is used to store the processes that are used by the tunnel.
@ -21,11 +23,11 @@ class Processes:
children: typing.List[ children: typing.List[
typing.Tuple['Connection', multiprocessing.Process, psutil.Process] typing.Tuple['Connection', multiprocessing.Process, psutil.Process]
] ]
process: typing.Callable process: ProcessType
cfg: config.ConfigurationType cfg: config.ConfigurationType
ns: 'Namespace' ns: 'Namespace'
def __init__(self, process: typing.Callable, cfg: config.ConfigurationType, ns: 'Namespace') -> None: def __init__(self, process: ProcessType, cfg: config.ConfigurationType, ns: 'Namespace') -> None:
self.children = [] self.children = []
self.process = process # type: ignore self.process = process # type: ignore
self.cfg = cfg self.cfg = cfg
@ -37,8 +39,8 @@ class Processes:
def add_child_pid(self): def add_child_pid(self):
own_conn, child_conn = multiprocessing.Pipe() own_conn, child_conn = multiprocessing.Pipe()
task = multiprocessing.Process( task = multiprocessing.Process(
target=curio.run, target=asyncio.run,
args=(self.process, child_conn, self.cfg, self.ns) args=(self.process(child_conn, self.cfg, self.ns),)
) )
task.start() task.start()
logger.debug('ADD CHILD PID: %s', task.pid) logger.debug('ADD CHILD PID: %s', task.pid)

View File

@ -28,10 +28,11 @@
''' '''
@author: Adolfo Gómez, dkmaster at dkmon dot com @author: Adolfo Gómez, dkmaster at dkmon dot com
''' '''
import asyncio
import socket
import logging import logging
import typing import typing
import curio
import requests import requests
from . import config from . import config
@ -40,7 +41,7 @@ from . import consts
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from multiprocessing.managers import Namespace from multiprocessing.managers import Namespace
import curio.io import ssl
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -109,29 +110,29 @@ class Proxy:
cfg, ticket, 'stop', {'sent': str(counter.sent), 'recv': str(counter.recv)} cfg, ticket, 'stop', {'sent': str(counter.sent), 'recv': str(counter.recv)}
) # Ignore results ) # Ignore results
@staticmethod # @staticmethod
async def doProxy( # async def doProxy(
source: 'curio.io.Socket', # source: 'curio.io.Socket',
destination: 'curio.io.Socket', # destination: 'curio.io.Socket',
counter: stats.StatsSingleCounter, # counter: stats.StatsSingleCounter,
) -> None: # ) -> None:
try: # try:
while True: # while True:
data = await source.recv(consts.BUFFER_SIZE) # data = await source.recv(consts.BUFFER_SIZE)
if not data: # if not data:
break # break
await destination.sendall(data) # await destination.sendall(data)
counter.add(len(data)) # counter.add(len(data))
except Exception: # except Exception:
# Connection broken, same result as closed for us # # Connection broken, same result as closed for us
# We must notice that i'ts easy that when closing one part of the tunnel, # # We must notice that i'ts easy that when closing one part of the tunnel,
# the other can break (due to some internal data), that's why even log is removed # # the other can break (due to some internal data), that's why even log is removed
# logger.info('CONNECTION LOST FROM %s to %s', source.getsockname(), destination.getpeername()) # # logger.info('CONNECTION LOST FROM %s to %s', source.getsockname(), destination.getpeername())
pass # pass
# Method responsible of proxying requests # Method responsible of proxying requests
async def __call__(self, source, address: typing.Tuple[str, int]) -> None: async def __call__(self, source: socket.socket, address: typing.Tuple[str, int], context: 'ssl.SSLContext') -> None:
await self.proxy(source, address) await self.proxy(source, address, context)
async def stats(self, full: bool, source, address: typing.Tuple[str, int]) -> None: async def stats(self, full: bool, source, address: typing.Tuple[str, int]) -> None:
# Check valid source ip # Check valid source ip
@ -153,17 +154,19 @@ class Proxy:
logger.debug('SENDING %s', v) logger.debug('SENDING %s', v)
await source.sendall(v.encode() + b'\n') await source.sendall(v.encode() + b'\n')
async def proxy(self, source, address: typing.Tuple[str, int]) -> None: async def proxy(self, source: socket.socket, address: typing.Tuple[str, int], context: 'ssl.SSLContext') -> None:
prettySource = address[0] # Get only source IP prettySource = address[0] # Get only source IP
prettyDest = '' prettyDest = ''
logger.info('CONNECT FROM %s', prettySource) logger.info('CONNECT FROM %s', prettySource)
loop = asyncio.get_event_loop()
# Handshake correct in this point, start SSL connection # Handshake correct in this point, start SSL connection
try: try:
command: bytes = await source.recv(consts.COMMAND_LENGTH) command: bytes = await loop.sock_recv(source, consts.COMMAND_LENGTH)
if command == consts.COMMAND_TEST: if command == consts.COMMAND_TEST:
logger.info('COMMAND: TEST') logger.info('COMMAND: TEST')
await source.sendall(b'OK') await loop.sock_sendall(source, b'OK')
logger.info('TERMINATED %s', prettySource) logger.info('TERMINATED %s', prettySource)
return return

View File

@ -29,15 +29,16 @@
@author: Adolfo Gómez, dkmaster at dkmon dot com @author: Adolfo Gómez, dkmaster at dkmon dot com
''' '''
import multiprocessing import multiprocessing
import socket
import time import time
import logging import logging
import typing import typing
import io import io
import asyncio
import ssl import ssl
import logging import logging
import typing import typing
import curio
from . import config from . import config
from . import consts from . import consts
@ -146,15 +147,24 @@ async def getServerStats(detailed: bool = False) -> None:
try: try:
host = cfg.listen_address if cfg.listen_address != '0.0.0.0' else 'localhost' host = cfg.listen_address if cfg.listen_address != '0.0.0.0' else 'localhost'
sock = await curio.open_connection( reader: asyncio.StreamReader
host, cfg.listen_port, ssl=context, server_hostname='localhost' writer: asyncio.StreamWriter
)
tmpdata = io.BytesIO() with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
cmd = consts.COMMAND_STAT if detailed else consts.COMMAND_INFO sock.connect((host, cfg.listen_port))
async with sock: # Send HANDSHAKE
await sock.sendall(consts.HANDSHAKE_V1 + cmd + cfg.secret.encode()) sock.sendall(consts.HANDSHAKE_V1)
# Ugrade connection to TLS
reader, writer = await asyncio.open_connection(sock=sock, ssl=context, server_hostname=host)
tmpdata = io.BytesIO()
cmd = consts.COMMAND_STAT if detailed else consts.COMMAND_INFO
writer.write(cmd + cfg.secret.encode())
await writer.drain()
while True: while True:
chunk = await sock.recv(consts.BUFFER_SIZE) chunk = await reader.read(consts.BUFFER_SIZE)
if not chunk: if not chunk:
break break
tmpdata.write(chunk) tmpdata.write(chunk)

View File

@ -1,7 +1,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
# Copyright (c) 2021 Virtual Cable S.L.U. # Copyright (c) 2021-2022 Virtual Cable S.L.U.
# All rights reserved. # All rights reserved.
# #
# Redistribution and use in source and binary forms, with or without modification, # Redistribution and use in source and binary forms, with or without modification,
@ -32,22 +32,18 @@
import os import os
import pwd import pwd
import sys import sys
import asyncio
import argparse import argparse
import signal import signal
import ssl
import socket import socket
import logging import logging
import typing import typing
import curio
import setproctitle import setproctitle
from uds_tunnel import config from uds_tunnel import config, proxy, consts, processes, stats
from uds_tunnel import proxy
from uds_tunnel import consts
from uds_tunnel import message
from uds_tunnel import stats
from uds_tunnel import processes
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from multiprocessing.connection import Connection from multiprocessing.connection import Connection
@ -100,17 +96,34 @@ def setup_log(cfg: config.ConfigurationType) -> None:
async def tunnel_proc_async( async def tunnel_proc_async(
pipe: 'Connection', cfg: config.ConfigurationType, ns: 'Namespace' pipe: 'Connection', cfg: config.ConfigurationType, ns: 'Namespace'
) -> None: ) -> None:
def get_socket(pipe: 'Connection') -> typing.Tuple[typing.Optional[socket.SocketType], typing.Any]:
loop = asyncio.get_event_loop()
# Create event for flagging when we have new data
event = asyncio.Event()
loop.add_reader(pipe.fileno(), event.set)
tasks: typing.List[asyncio.Task] = []
async def get_socket() -> typing.Tuple[
typing.Optional[socket.socket], typing.Tuple[str, int]
]:
try: try:
while True: while True:
msg: message.Message = pipe.recv() await event.wait()
if msg.command == message.Command.TUNNEL and msg.connection: # Clear back event, for next data
event.clear()
msg: typing.Optional[
typing.Tuple[socket.socket, typing.Tuple[str, int]]
] = pipe.recv()
if msg:
# Connection done, check for handshake # Connection done, check for handshake
source, address = msg.connection source, address = msg
try: try:
# First, ensure handshake (simple handshake) and command # First, ensure handshake (simple handshake) and command
data: bytes = source.recv(len(consts.HANDSHAKE_V1)) data: bytes = await loop.sock_recv(
source, len(consts.HANDSHAKE_V1)
)
if data != consts.HANDSHAKE_V1: if data != consts.HANDSHAKE_V1:
raise Exception() # Invalid handshake raise Exception() # Invalid handshake
@ -122,21 +135,19 @@ async def tunnel_proc_async(
source.close() source.close()
continue continue
return msg.connection return msg
# Process other messages, and retry # Process other messages, and retry
except Exception: except Exception:
logger.exception('Receiving data from parent process') logger.exception('Receiving data from parent process')
return None, None return None, ('', 0)
async def run_server( async def run_server() -> None:
pipe: 'Connection', cfg: config.ConfigurationType, group: curio.TaskGroup
) -> None:
# Instantiate a proxy redirector for this process (we only need one per process!!) # Instantiate a proxy redirector for this process (we only need one per process!!)
tunneler = proxy.Proxy(cfg, ns) tunneler = proxy.Proxy(cfg, ns)
# Generate SSL context # Generate SSL context
context = curio.ssl.SSLContext(curio.ssl.PROTOCOL_TLS_SERVER) # type: ignore context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
context.load_cert_chain(cfg.ssl_certificate, cfg.ssl_certificate_key) context.load_cert_chain(cfg.ssl_certificate, cfg.ssl_certificate_key)
if cfg.ssl_ciphers: if cfg.ssl_ciphers:
@ -146,30 +157,27 @@ async def tunnel_proc_async(
context.load_dh_params(cfg.ssl_dhparam) context.load_dh_params(cfg.ssl_dhparam)
while True: while True:
address = ('', '') address: typing.Tuple[str, int] = ('', 0)
try: try:
sock, address = await curio.run_in_thread(get_socket, pipe) sock, address = await get_socket()
if not sock: if not sock:
break break # No more sockets, exit
logger.debug( logger.debug(f'CONNECTION from {address!r} (pid: {os.getpid()})')
f'CONNECTION from {address!r} (pid: {os.getpid()})' tasks.append(asyncio.create_task(tunneler(sock, address, context)))
)
sock = await context.wrap_socket(
curio.io.Socket(sock), server_side=True # type: ignore
)
await group.spawn(tunneler, sock, address)
del sock
except Exception: except Exception:
logger.error('NEGOTIATION ERROR from %s', address[0]) logger.error('NEGOTIATION ERROR from %s', address[0])
async with curio.TaskGroup() as tg: # create task for server
await tg.spawn(run_server, pipe, cfg, tg) tasks.append(asyncio.create_task(run_server()))
await tg.join()
# Reap all of the children tasks as they complete while tasks:
# async for task in tg: tasks_number = len(tasks)
# logger.debug(f'REMOVING async task {task!r}') await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)
# task.joined = True # Remove finished tasks from list
# del task del tasks[:tasks_number]
# Remove reader from event loop
loop.remove_reader(pipe.fileno())
def tunnel_main(): def tunnel_main():
@ -200,7 +208,9 @@ def tunnel_main():
setup_log(cfg) setup_log(cfg)
logger.info('Starting tunnel server on %s:%s', cfg.listen_address, cfg.listen_port) logger.info(
'Starting tunnel server on %s:%s', cfg.listen_address, cfg.listen_port
)
setproctitle.setproctitle(f'UDSTunnel {cfg.listen_address}:{cfg.listen_port}') setproctitle.setproctitle(f'UDSTunnel {cfg.listen_address}:{cfg.listen_port}')
# Create pid file # Create pid file
@ -213,7 +223,6 @@ def tunnel_main():
logger.error('MAIN: %s', e) logger.error('MAIN: %s', e)
return return
# Setup signal handlers # Setup signal handlers
signal.signal(signal.SIGINT, stop_signal) signal.signal(signal.SIGINT, stop_signal)
signal.signal(signal.SIGTERM, stop_signal) signal.signal(signal.SIGTERM, stop_signal)
@ -227,9 +236,7 @@ def tunnel_main():
try: try:
client, addr = sock.accept() client, addr = sock.accept()
# Select BEST process for sending this new connection # Select BEST process for sending this new connection
prcs.best_child().send( prcs.best_child().send((client, addr))
message.Message(message.Command.TUNNEL, (client, addr))
)
del client # Ensure socket is controlled on child process del client # Ensure socket is controlled on child process
except socket.timeout: except socket.timeout:
pass # Continue and retry pass # Continue and retry
@ -259,9 +266,7 @@ def main() -> None:
group.add_argument( group.add_argument(
'-t', '--tunnel', help='Starts the tunnel server', action='store_true' '-t', '--tunnel', help='Starts the tunnel server', action='store_true'
) )
group.add_argument( group.add_argument('-r', '--rdp', help='RDP Tunnel for traffic accounting')
'-r', '--rdp', help='RDP Tunnel for traffic accounting'
)
group.add_argument( group.add_argument(
'-s', '-s',
'--stats', '--stats',
@ -281,12 +286,12 @@ def main() -> None:
elif args.rdp: elif args.rdp:
pass pass
elif args.detailed_stats: elif args.detailed_stats:
curio.run(stats.getServerStats, True) asyncio.run(stats.getServerStats(True))
elif args.stats: elif args.stats:
curio.run(stats.getServerStats, False) asyncio.run(stats.getServerStats(False))
else: else:
parser.print_help() parser.print_help()
if __name__ == "__main__": if __name__ == "__main__":
main() main()