Moving from curio to asyncio.

This commit is contained in:
Adolfo Gómez García 2022-01-26 12:18:41 +01:00
parent 143b9b675b
commit 7d8ae689b5
6 changed files with 109 additions and 135 deletions

View File

@ -1,2 +1 @@
curio>=1.4
psutil>=5.7.3

View File

@ -1,45 +0,0 @@
# -*- coding: utf-8 -*-
#
# Copyright (c) 2021 Virtual Cable S.L.U.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of Virtual Cable S.L. nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
'''
@author: Adolfo Gómez, dkmaster at dkmon dot com
'''
import enum
import socket
import typing
class Command(enum.IntEnum):
TUNNEL = 0
STATS = 1
class Message:
command: Command
connection: typing.Optional[typing.Tuple[socket.socket, typing.Any]]
def __init__(self, command: Command, connection: typing.Optional[typing.Tuple[socket.socket, typing.Any]]):
self.command = command
self.connection = connection

View File

@ -1,8 +1,8 @@
import multiprocessing
import asyncio
import logging
import typing
import curio
import psutil
from . import config
@ -13,6 +13,8 @@ if typing.TYPE_CHECKING:
logger = logging.getLogger(__name__)
ProcessType = typing.Callable[['Connection', config.ConfigurationType, 'Namespace'], typing.Coroutine[typing.Any, None, None]]
class Processes:
"""
This class is used to store the processes that are used by the tunnel.
@ -21,11 +23,11 @@ class Processes:
children: typing.List[
typing.Tuple['Connection', multiprocessing.Process, psutil.Process]
]
process: typing.Callable
process: ProcessType
cfg: config.ConfigurationType
ns: 'Namespace'
def __init__(self, process: typing.Callable, cfg: config.ConfigurationType, ns: 'Namespace') -> None:
def __init__(self, process: ProcessType, cfg: config.ConfigurationType, ns: 'Namespace') -> None:
self.children = []
self.process = process # type: ignore
self.cfg = cfg
@ -37,8 +39,8 @@ class Processes:
def add_child_pid(self):
own_conn, child_conn = multiprocessing.Pipe()
task = multiprocessing.Process(
target=curio.run,
args=(self.process, child_conn, self.cfg, self.ns)
target=asyncio.run,
args=(self.process(child_conn, self.cfg, self.ns),)
)
task.start()
logger.debug('ADD CHILD PID: %s', task.pid)

View File

@ -28,10 +28,11 @@
'''
@author: Adolfo Gómez, dkmaster at dkmon dot com
'''
import asyncio
import socket
import logging
import typing
import curio
import requests
from . import config
@ -40,7 +41,7 @@ from . import consts
if typing.TYPE_CHECKING:
from multiprocessing.managers import Namespace
import curio.io
import ssl
logger = logging.getLogger(__name__)
@ -109,29 +110,29 @@ class Proxy:
cfg, ticket, 'stop', {'sent': str(counter.sent), 'recv': str(counter.recv)}
) # Ignore results
@staticmethod
async def doProxy(
source: 'curio.io.Socket',
destination: 'curio.io.Socket',
counter: stats.StatsSingleCounter,
) -> None:
try:
while True:
data = await source.recv(consts.BUFFER_SIZE)
if not data:
break
await destination.sendall(data)
counter.add(len(data))
except Exception:
# Connection broken, same result as closed for us
# We must notice that i'ts easy that when closing one part of the tunnel,
# the other can break (due to some internal data), that's why even log is removed
# logger.info('CONNECTION LOST FROM %s to %s', source.getsockname(), destination.getpeername())
pass
# @staticmethod
# async def doProxy(
# source: 'curio.io.Socket',
# destination: 'curio.io.Socket',
# counter: stats.StatsSingleCounter,
# ) -> None:
# try:
# while True:
# data = await source.recv(consts.BUFFER_SIZE)
# if not data:
# break
# await destination.sendall(data)
# counter.add(len(data))
# except Exception:
# # Connection broken, same result as closed for us
# # We must notice that i'ts easy that when closing one part of the tunnel,
# # the other can break (due to some internal data), that's why even log is removed
# # logger.info('CONNECTION LOST FROM %s to %s', source.getsockname(), destination.getpeername())
# pass
# Method responsible of proxying requests
async def __call__(self, source, address: typing.Tuple[str, int]) -> None:
await self.proxy(source, address)
async def __call__(self, source: socket.socket, address: typing.Tuple[str, int], context: 'ssl.SSLContext') -> None:
await self.proxy(source, address, context)
async def stats(self, full: bool, source, address: typing.Tuple[str, int]) -> None:
# Check valid source ip
@ -153,17 +154,19 @@ class Proxy:
logger.debug('SENDING %s', v)
await source.sendall(v.encode() + b'\n')
async def proxy(self, source, address: typing.Tuple[str, int]) -> None:
async def proxy(self, source: socket.socket, address: typing.Tuple[str, int], context: 'ssl.SSLContext') -> None:
prettySource = address[0] # Get only source IP
prettyDest = ''
logger.info('CONNECT FROM %s', prettySource)
loop = asyncio.get_event_loop()
# Handshake correct in this point, start SSL connection
try:
command: bytes = await source.recv(consts.COMMAND_LENGTH)
command: bytes = await loop.sock_recv(source, consts.COMMAND_LENGTH)
if command == consts.COMMAND_TEST:
logger.info('COMMAND: TEST')
await source.sendall(b'OK')
await loop.sock_sendall(source, b'OK')
logger.info('TERMINATED %s', prettySource)
return

View File

@ -29,15 +29,16 @@
@author: Adolfo Gómez, dkmaster at dkmon dot com
'''
import multiprocessing
import socket
import time
import logging
import typing
import io
import asyncio
import ssl
import logging
import typing
import curio
from . import config
from . import consts
@ -146,15 +147,24 @@ async def getServerStats(detailed: bool = False) -> None:
try:
host = cfg.listen_address if cfg.listen_address != '0.0.0.0' else 'localhost'
sock = await curio.open_connection(
host, cfg.listen_port, ssl=context, server_hostname='localhost'
)
reader: asyncio.StreamReader
writer: asyncio.StreamWriter
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.connect((host, cfg.listen_port))
# Send HANDSHAKE
sock.sendall(consts.HANDSHAKE_V1)
# Ugrade connection to TLS
reader, writer = await asyncio.open_connection(sock=sock, ssl=context, server_hostname=host)
tmpdata = io.BytesIO()
cmd = consts.COMMAND_STAT if detailed else consts.COMMAND_INFO
async with sock:
await sock.sendall(consts.HANDSHAKE_V1 + cmd + cfg.secret.encode())
writer.write(cmd + cfg.secret.encode())
await writer.drain()
while True:
chunk = await sock.recv(consts.BUFFER_SIZE)
chunk = await reader.read(consts.BUFFER_SIZE)
if not chunk:
break
tmpdata.write(chunk)

View File

@ -1,7 +1,7 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (c) 2021 Virtual Cable S.L.U.
# Copyright (c) 2021-2022 Virtual Cable S.L.U.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
@ -32,22 +32,18 @@
import os
import pwd
import sys
import asyncio
import argparse
import signal
import ssl
import socket
import logging
import typing
import curio
import setproctitle
from uds_tunnel import config
from uds_tunnel import proxy
from uds_tunnel import consts
from uds_tunnel import message
from uds_tunnel import stats
from uds_tunnel import processes
from uds_tunnel import config, proxy, consts, processes, stats
if typing.TYPE_CHECKING:
from multiprocessing.connection import Connection
@ -100,17 +96,34 @@ def setup_log(cfg: config.ConfigurationType) -> None:
async def tunnel_proc_async(
pipe: 'Connection', cfg: config.ConfigurationType, ns: 'Namespace'
) -> None:
def get_socket(pipe: 'Connection') -> typing.Tuple[typing.Optional[socket.SocketType], typing.Any]:
loop = asyncio.get_event_loop()
# Create event for flagging when we have new data
event = asyncio.Event()
loop.add_reader(pipe.fileno(), event.set)
tasks: typing.List[asyncio.Task] = []
async def get_socket() -> typing.Tuple[
typing.Optional[socket.socket], typing.Tuple[str, int]
]:
try:
while True:
msg: message.Message = pipe.recv()
if msg.command == message.Command.TUNNEL and msg.connection:
await event.wait()
# Clear back event, for next data
event.clear()
msg: typing.Optional[
typing.Tuple[socket.socket, typing.Tuple[str, int]]
] = pipe.recv()
if msg:
# Connection done, check for handshake
source, address = msg.connection
source, address = msg
try:
# First, ensure handshake (simple handshake) and command
data: bytes = source.recv(len(consts.HANDSHAKE_V1))
data: bytes = await loop.sock_recv(
source, len(consts.HANDSHAKE_V1)
)
if data != consts.HANDSHAKE_V1:
raise Exception() # Invalid handshake
@ -122,21 +135,19 @@ async def tunnel_proc_async(
source.close()
continue
return msg.connection
return msg
# Process other messages, and retry
except Exception:
logger.exception('Receiving data from parent process')
return None, None
return None, ('', 0)
async def run_server(
pipe: 'Connection', cfg: config.ConfigurationType, group: curio.TaskGroup
) -> None:
async def run_server() -> None:
# Instantiate a proxy redirector for this process (we only need one per process!!)
tunneler = proxy.Proxy(cfg, ns)
# Generate SSL context
context = curio.ssl.SSLContext(curio.ssl.PROTOCOL_TLS_SERVER) # type: ignore
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
context.load_cert_chain(cfg.ssl_certificate, cfg.ssl_certificate_key)
if cfg.ssl_ciphers:
@ -146,30 +157,27 @@ async def tunnel_proc_async(
context.load_dh_params(cfg.ssl_dhparam)
while True:
address = ('', '')
address: typing.Tuple[str, int] = ('', 0)
try:
sock, address = await curio.run_in_thread(get_socket, pipe)
sock, address = await get_socket()
if not sock:
break
logger.debug(
f'CONNECTION from {address!r} (pid: {os.getpid()})'
)
sock = await context.wrap_socket(
curio.io.Socket(sock), server_side=True # type: ignore
)
await group.spawn(tunneler, sock, address)
del sock
break # No more sockets, exit
logger.debug(f'CONNECTION from {address!r} (pid: {os.getpid()})')
tasks.append(asyncio.create_task(tunneler(sock, address, context)))
except Exception:
logger.error('NEGOTIATION ERROR from %s', address[0])
async with curio.TaskGroup() as tg:
await tg.spawn(run_server, pipe, cfg, tg)
await tg.join()
# Reap all of the children tasks as they complete
# async for task in tg:
# logger.debug(f'REMOVING async task {task!r}')
# task.joined = True
# del task
# create task for server
tasks.append(asyncio.create_task(run_server()))
while tasks:
tasks_number = len(tasks)
await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)
# Remove finished tasks from list
del tasks[:tasks_number]
# Remove reader from event loop
loop.remove_reader(pipe.fileno())
def tunnel_main():
@ -200,7 +208,9 @@ def tunnel_main():
setup_log(cfg)
logger.info('Starting tunnel server on %s:%s', cfg.listen_address, cfg.listen_port)
logger.info(
'Starting tunnel server on %s:%s', cfg.listen_address, cfg.listen_port
)
setproctitle.setproctitle(f'UDSTunnel {cfg.listen_address}:{cfg.listen_port}')
# Create pid file
@ -213,7 +223,6 @@ def tunnel_main():
logger.error('MAIN: %s', e)
return
# Setup signal handlers
signal.signal(signal.SIGINT, stop_signal)
signal.signal(signal.SIGTERM, stop_signal)
@ -227,9 +236,7 @@ def tunnel_main():
try:
client, addr = sock.accept()
# Select BEST process for sending this new connection
prcs.best_child().send(
message.Message(message.Command.TUNNEL, (client, addr))
)
prcs.best_child().send((client, addr))
del client # Ensure socket is controlled on child process
except socket.timeout:
pass # Continue and retry
@ -259,9 +266,7 @@ def main() -> None:
group.add_argument(
'-t', '--tunnel', help='Starts the tunnel server', action='store_true'
)
group.add_argument(
'-r', '--rdp', help='RDP Tunnel for traffic accounting'
)
group.add_argument('-r', '--rdp', help='RDP Tunnel for traffic accounting')
group.add_argument(
'-s',
'--stats',
@ -281,9 +286,9 @@ def main() -> None:
elif args.rdp:
pass
elif args.detailed_stats:
curio.run(stats.getServerStats, True)
asyncio.run(stats.getServerStats(True))
elif args.stats:
curio.run(stats.getServerStats, False)
asyncio.run(stats.getServerStats(False))
else:
parser.print_help()