forked from shaba/openuds
Moving from curio to asyncio.
This commit is contained in:
parent
143b9b675b
commit
7d8ae689b5
@ -1,2 +1 @@
|
||||
curio>=1.4
|
||||
psutil>=5.7.3
|
||||
|
@ -1,45 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright (c) 2021 Virtual Cable S.L.U.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification,
|
||||
# are permitted provided that the following conditions are met:
|
||||
#
|
||||
# * Redistributions of source code must retain the above copyright notice,
|
||||
# this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
# * Neither the name of Virtual Cable S.L. nor the names of its contributors
|
||||
# may be used to endorse or promote products derived from this software
|
||||
# without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
'''
|
||||
@author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
'''
|
||||
import enum
|
||||
import socket
|
||||
import typing
|
||||
|
||||
class Command(enum.IntEnum):
|
||||
TUNNEL = 0
|
||||
STATS = 1
|
||||
|
||||
class Message:
|
||||
command: Command
|
||||
connection: typing.Optional[typing.Tuple[socket.socket, typing.Any]]
|
||||
|
||||
def __init__(self, command: Command, connection: typing.Optional[typing.Tuple[socket.socket, typing.Any]]):
|
||||
self.command = command
|
||||
self.connection = connection
|
@ -1,8 +1,8 @@
|
||||
import multiprocessing
|
||||
import asyncio
|
||||
import logging
|
||||
import typing
|
||||
|
||||
import curio
|
||||
import psutil
|
||||
|
||||
from . import config
|
||||
@ -13,6 +13,8 @@ if typing.TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ProcessType = typing.Callable[['Connection', config.ConfigurationType, 'Namespace'], typing.Coroutine[typing.Any, None, None]]
|
||||
|
||||
class Processes:
|
||||
"""
|
||||
This class is used to store the processes that are used by the tunnel.
|
||||
@ -21,11 +23,11 @@ class Processes:
|
||||
children: typing.List[
|
||||
typing.Tuple['Connection', multiprocessing.Process, psutil.Process]
|
||||
]
|
||||
process: typing.Callable
|
||||
process: ProcessType
|
||||
cfg: config.ConfigurationType
|
||||
ns: 'Namespace'
|
||||
|
||||
def __init__(self, process: typing.Callable, cfg: config.ConfigurationType, ns: 'Namespace') -> None:
|
||||
def __init__(self, process: ProcessType, cfg: config.ConfigurationType, ns: 'Namespace') -> None:
|
||||
self.children = []
|
||||
self.process = process # type: ignore
|
||||
self.cfg = cfg
|
||||
@ -37,8 +39,8 @@ class Processes:
|
||||
def add_child_pid(self):
|
||||
own_conn, child_conn = multiprocessing.Pipe()
|
||||
task = multiprocessing.Process(
|
||||
target=curio.run,
|
||||
args=(self.process, child_conn, self.cfg, self.ns)
|
||||
target=asyncio.run,
|
||||
args=(self.process(child_conn, self.cfg, self.ns),)
|
||||
)
|
||||
task.start()
|
||||
logger.debug('ADD CHILD PID: %s', task.pid)
|
||||
|
@ -28,10 +28,11 @@
|
||||
'''
|
||||
@author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
'''
|
||||
import asyncio
|
||||
import socket
|
||||
import logging
|
||||
import typing
|
||||
|
||||
import curio
|
||||
import requests
|
||||
|
||||
from . import config
|
||||
@ -40,7 +41,7 @@ from . import consts
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from multiprocessing.managers import Namespace
|
||||
import curio.io
|
||||
import ssl
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -109,29 +110,29 @@ class Proxy:
|
||||
cfg, ticket, 'stop', {'sent': str(counter.sent), 'recv': str(counter.recv)}
|
||||
) # Ignore results
|
||||
|
||||
@staticmethod
|
||||
async def doProxy(
|
||||
source: 'curio.io.Socket',
|
||||
destination: 'curio.io.Socket',
|
||||
counter: stats.StatsSingleCounter,
|
||||
) -> None:
|
||||
try:
|
||||
while True:
|
||||
data = await source.recv(consts.BUFFER_SIZE)
|
||||
if not data:
|
||||
break
|
||||
await destination.sendall(data)
|
||||
counter.add(len(data))
|
||||
except Exception:
|
||||
# Connection broken, same result as closed for us
|
||||
# We must notice that i'ts easy that when closing one part of the tunnel,
|
||||
# the other can break (due to some internal data), that's why even log is removed
|
||||
# logger.info('CONNECTION LOST FROM %s to %s', source.getsockname(), destination.getpeername())
|
||||
pass
|
||||
# @staticmethod
|
||||
# async def doProxy(
|
||||
# source: 'curio.io.Socket',
|
||||
# destination: 'curio.io.Socket',
|
||||
# counter: stats.StatsSingleCounter,
|
||||
# ) -> None:
|
||||
# try:
|
||||
# while True:
|
||||
# data = await source.recv(consts.BUFFER_SIZE)
|
||||
# if not data:
|
||||
# break
|
||||
# await destination.sendall(data)
|
||||
# counter.add(len(data))
|
||||
# except Exception:
|
||||
# # Connection broken, same result as closed for us
|
||||
# # We must notice that i'ts easy that when closing one part of the tunnel,
|
||||
# # the other can break (due to some internal data), that's why even log is removed
|
||||
# # logger.info('CONNECTION LOST FROM %s to %s', source.getsockname(), destination.getpeername())
|
||||
# pass
|
||||
|
||||
# Method responsible of proxying requests
|
||||
async def __call__(self, source, address: typing.Tuple[str, int]) -> None:
|
||||
await self.proxy(source, address)
|
||||
async def __call__(self, source: socket.socket, address: typing.Tuple[str, int], context: 'ssl.SSLContext') -> None:
|
||||
await self.proxy(source, address, context)
|
||||
|
||||
async def stats(self, full: bool, source, address: typing.Tuple[str, int]) -> None:
|
||||
# Check valid source ip
|
||||
@ -153,17 +154,19 @@ class Proxy:
|
||||
logger.debug('SENDING %s', v)
|
||||
await source.sendall(v.encode() + b'\n')
|
||||
|
||||
async def proxy(self, source, address: typing.Tuple[str, int]) -> None:
|
||||
async def proxy(self, source: socket.socket, address: typing.Tuple[str, int], context: 'ssl.SSLContext') -> None:
|
||||
prettySource = address[0] # Get only source IP
|
||||
prettyDest = ''
|
||||
logger.info('CONNECT FROM %s', prettySource)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
# Handshake correct in this point, start SSL connection
|
||||
try:
|
||||
command: bytes = await source.recv(consts.COMMAND_LENGTH)
|
||||
command: bytes = await loop.sock_recv(source, consts.COMMAND_LENGTH)
|
||||
if command == consts.COMMAND_TEST:
|
||||
logger.info('COMMAND: TEST')
|
||||
await source.sendall(b'OK')
|
||||
await loop.sock_sendall(source, b'OK')
|
||||
logger.info('TERMINATED %s', prettySource)
|
||||
return
|
||||
|
||||
|
@ -29,15 +29,16 @@
|
||||
@author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
'''
|
||||
import multiprocessing
|
||||
import socket
|
||||
import time
|
||||
import logging
|
||||
import typing
|
||||
import io
|
||||
import asyncio
|
||||
import ssl
|
||||
import logging
|
||||
import typing
|
||||
|
||||
import curio
|
||||
|
||||
from . import config
|
||||
from . import consts
|
||||
@ -146,15 +147,24 @@ async def getServerStats(detailed: bool = False) -> None:
|
||||
|
||||
try:
|
||||
host = cfg.listen_address if cfg.listen_address != '0.0.0.0' else 'localhost'
|
||||
sock = await curio.open_connection(
|
||||
host, cfg.listen_port, ssl=context, server_hostname='localhost'
|
||||
)
|
||||
reader: asyncio.StreamReader
|
||||
writer: asyncio.StreamWriter
|
||||
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||
sock.connect((host, cfg.listen_port))
|
||||
# Send HANDSHAKE
|
||||
sock.sendall(consts.HANDSHAKE_V1)
|
||||
# Ugrade connection to TLS
|
||||
reader, writer = await asyncio.open_connection(sock=sock, ssl=context, server_hostname=host)
|
||||
|
||||
tmpdata = io.BytesIO()
|
||||
cmd = consts.COMMAND_STAT if detailed else consts.COMMAND_INFO
|
||||
async with sock:
|
||||
await sock.sendall(consts.HANDSHAKE_V1 + cmd + cfg.secret.encode())
|
||||
|
||||
writer.write(cmd + cfg.secret.encode())
|
||||
await writer.drain()
|
||||
|
||||
while True:
|
||||
chunk = await sock.recv(consts.BUFFER_SIZE)
|
||||
chunk = await reader.read(consts.BUFFER_SIZE)
|
||||
if not chunk:
|
||||
break
|
||||
tmpdata.write(chunk)
|
||||
|
@ -1,7 +1,7 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright (c) 2021 Virtual Cable S.L.U.
|
||||
# Copyright (c) 2021-2022 Virtual Cable S.L.U.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification,
|
||||
@ -32,22 +32,18 @@
|
||||
import os
|
||||
import pwd
|
||||
import sys
|
||||
import asyncio
|
||||
import argparse
|
||||
import signal
|
||||
import ssl
|
||||
import socket
|
||||
import logging
|
||||
import typing
|
||||
|
||||
import curio
|
||||
import setproctitle
|
||||
|
||||
|
||||
from uds_tunnel import config
|
||||
from uds_tunnel import proxy
|
||||
from uds_tunnel import consts
|
||||
from uds_tunnel import message
|
||||
from uds_tunnel import stats
|
||||
from uds_tunnel import processes
|
||||
from uds_tunnel import config, proxy, consts, processes, stats
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from multiprocessing.connection import Connection
|
||||
@ -100,17 +96,34 @@ def setup_log(cfg: config.ConfigurationType) -> None:
|
||||
async def tunnel_proc_async(
|
||||
pipe: 'Connection', cfg: config.ConfigurationType, ns: 'Namespace'
|
||||
) -> None:
|
||||
def get_socket(pipe: 'Connection') -> typing.Tuple[typing.Optional[socket.SocketType], typing.Any]:
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
# Create event for flagging when we have new data
|
||||
event = asyncio.Event()
|
||||
loop.add_reader(pipe.fileno(), event.set)
|
||||
|
||||
tasks: typing.List[asyncio.Task] = []
|
||||
|
||||
async def get_socket() -> typing.Tuple[
|
||||
typing.Optional[socket.socket], typing.Tuple[str, int]
|
||||
]:
|
||||
try:
|
||||
while True:
|
||||
msg: message.Message = pipe.recv()
|
||||
if msg.command == message.Command.TUNNEL and msg.connection:
|
||||
await event.wait()
|
||||
# Clear back event, for next data
|
||||
event.clear()
|
||||
msg: typing.Optional[
|
||||
typing.Tuple[socket.socket, typing.Tuple[str, int]]
|
||||
] = pipe.recv()
|
||||
if msg:
|
||||
# Connection done, check for handshake
|
||||
source, address = msg.connection
|
||||
source, address = msg
|
||||
|
||||
try:
|
||||
# First, ensure handshake (simple handshake) and command
|
||||
data: bytes = source.recv(len(consts.HANDSHAKE_V1))
|
||||
data: bytes = await loop.sock_recv(
|
||||
source, len(consts.HANDSHAKE_V1)
|
||||
)
|
||||
|
||||
if data != consts.HANDSHAKE_V1:
|
||||
raise Exception() # Invalid handshake
|
||||
@ -122,21 +135,19 @@ async def tunnel_proc_async(
|
||||
source.close()
|
||||
continue
|
||||
|
||||
return msg.connection
|
||||
return msg
|
||||
|
||||
# Process other messages, and retry
|
||||
except Exception:
|
||||
logger.exception('Receiving data from parent process')
|
||||
return None, None
|
||||
return None, ('', 0)
|
||||
|
||||
async def run_server(
|
||||
pipe: 'Connection', cfg: config.ConfigurationType, group: curio.TaskGroup
|
||||
) -> None:
|
||||
async def run_server() -> None:
|
||||
# Instantiate a proxy redirector for this process (we only need one per process!!)
|
||||
tunneler = proxy.Proxy(cfg, ns)
|
||||
|
||||
# Generate SSL context
|
||||
context = curio.ssl.SSLContext(curio.ssl.PROTOCOL_TLS_SERVER) # type: ignore
|
||||
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||
context.load_cert_chain(cfg.ssl_certificate, cfg.ssl_certificate_key)
|
||||
|
||||
if cfg.ssl_ciphers:
|
||||
@ -146,30 +157,27 @@ async def tunnel_proc_async(
|
||||
context.load_dh_params(cfg.ssl_dhparam)
|
||||
|
||||
while True:
|
||||
address = ('', '')
|
||||
address: typing.Tuple[str, int] = ('', 0)
|
||||
try:
|
||||
sock, address = await curio.run_in_thread(get_socket, pipe)
|
||||
sock, address = await get_socket()
|
||||
if not sock:
|
||||
break
|
||||
logger.debug(
|
||||
f'CONNECTION from {address!r} (pid: {os.getpid()})'
|
||||
)
|
||||
sock = await context.wrap_socket(
|
||||
curio.io.Socket(sock), server_side=True # type: ignore
|
||||
)
|
||||
await group.spawn(tunneler, sock, address)
|
||||
del sock
|
||||
break # No more sockets, exit
|
||||
logger.debug(f'CONNECTION from {address!r} (pid: {os.getpid()})')
|
||||
tasks.append(asyncio.create_task(tunneler(sock, address, context)))
|
||||
except Exception:
|
||||
logger.error('NEGOTIATION ERROR from %s', address[0])
|
||||
|
||||
async with curio.TaskGroup() as tg:
|
||||
await tg.spawn(run_server, pipe, cfg, tg)
|
||||
await tg.join()
|
||||
# Reap all of the children tasks as they complete
|
||||
# async for task in tg:
|
||||
# logger.debug(f'REMOVING async task {task!r}')
|
||||
# task.joined = True
|
||||
# del task
|
||||
# create task for server
|
||||
tasks.append(asyncio.create_task(run_server()))
|
||||
|
||||
while tasks:
|
||||
tasks_number = len(tasks)
|
||||
await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)
|
||||
# Remove finished tasks from list
|
||||
del tasks[:tasks_number]
|
||||
|
||||
# Remove reader from event loop
|
||||
loop.remove_reader(pipe.fileno())
|
||||
|
||||
|
||||
def tunnel_main():
|
||||
@ -200,7 +208,9 @@ def tunnel_main():
|
||||
|
||||
setup_log(cfg)
|
||||
|
||||
logger.info('Starting tunnel server on %s:%s', cfg.listen_address, cfg.listen_port)
|
||||
logger.info(
|
||||
'Starting tunnel server on %s:%s', cfg.listen_address, cfg.listen_port
|
||||
)
|
||||
setproctitle.setproctitle(f'UDSTunnel {cfg.listen_address}:{cfg.listen_port}')
|
||||
|
||||
# Create pid file
|
||||
@ -213,7 +223,6 @@ def tunnel_main():
|
||||
logger.error('MAIN: %s', e)
|
||||
return
|
||||
|
||||
|
||||
# Setup signal handlers
|
||||
signal.signal(signal.SIGINT, stop_signal)
|
||||
signal.signal(signal.SIGTERM, stop_signal)
|
||||
@ -227,9 +236,7 @@ def tunnel_main():
|
||||
try:
|
||||
client, addr = sock.accept()
|
||||
# Select BEST process for sending this new connection
|
||||
prcs.best_child().send(
|
||||
message.Message(message.Command.TUNNEL, (client, addr))
|
||||
)
|
||||
prcs.best_child().send((client, addr))
|
||||
del client # Ensure socket is controlled on child process
|
||||
except socket.timeout:
|
||||
pass # Continue and retry
|
||||
@ -259,9 +266,7 @@ def main() -> None:
|
||||
group.add_argument(
|
||||
'-t', '--tunnel', help='Starts the tunnel server', action='store_true'
|
||||
)
|
||||
group.add_argument(
|
||||
'-r', '--rdp', help='RDP Tunnel for traffic accounting'
|
||||
)
|
||||
group.add_argument('-r', '--rdp', help='RDP Tunnel for traffic accounting')
|
||||
group.add_argument(
|
||||
'-s',
|
||||
'--stats',
|
||||
@ -281,9 +286,9 @@ def main() -> None:
|
||||
elif args.rdp:
|
||||
pass
|
||||
elif args.detailed_stats:
|
||||
curio.run(stats.getServerStats, True)
|
||||
asyncio.run(stats.getServerStats(True))
|
||||
elif args.stats:
|
||||
curio.run(stats.getServerStats, False)
|
||||
asyncio.run(stats.getServerStats(False))
|
||||
else:
|
||||
parser.print_help()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user