1
0
mirror of https://github.com/dkmstr/openuds.git synced 2024-12-25 23:21:41 +03:00

Added ThreadPoolExecutor to check handshakes...

This commit is contained in:
Adolfo Gómez García 2022-03-30 15:44:49 +02:00
parent 69192a2a1b
commit 0f3f50f63c

View File

@ -38,7 +38,7 @@ import signal
import ssl
import socket
import logging
import threading
from concurrent.futures import ThreadPoolExecutor
import typing
import setproctitle
@ -129,27 +129,6 @@ async def tunnel_proc_async(
if cfg.ssl_dhparam:
context.load_dh_params(cfg.ssl_dhparam)
async def processSocket(ssock: socket.socket) -> None:
sock = curio.io.Socket(ssock)
try:
# First, ensure handshake (simple handshake) and command
async with curio.timeout_after(3): # type: ignore
data = await sock.recv(len(consts.HANDSHAKE_V1))
if data != consts.HANDSHAKE_V1:
raise Exception(data) # Invalid handshake
except (curio.errors.CancelledError, Exception) as e:
logger.error('HANDSHAKE from %s (%s)', address, 'timeout' if isinstance(e, curio.errors.CancelledError) else e)
# Close Source and continue
await sock.close()
return
sslsock = await context.wrap_socket(
sock, server_side=True # type: ignore
)
await group.spawn(tunneler, sslsock, address)
del sslsock
while True:
address: typing.Tuple[str, int] = ('', 0)
try:
@ -170,6 +149,7 @@ async def tunnel_proc_async(
# Remove finished tasks from list
del tasks[:tasks_number]
logger.info('PROCESS %s stopped', os.getpid())
def process_connection(
client: socket.socket, addr: typing.Tuple[str, str], conn: 'Connection'
@ -188,8 +168,6 @@ def process_connection(
# Close Source and continue
client.close()
logger.info('PROCESS %s stopped', os.getpid())
def tunnel_main():
cfg = config.read()
@ -244,25 +222,24 @@ def tunnel_main():
prcs = processes.Processes(tunnel_proc_async, cfg, stats_collector.ns)
try:
while not do_stop:
try:
client, addr = sock.accept()
logger.info('CONNECTION from %s', addr)
with ThreadPoolExecutor(max_workers=256) as executor:
try:
while not do_stop:
try:
client, addr = sock.accept()
logger.info('CONNECTION from %s', addr)
# Check if we have reached the max number of connections
# First part is checked on a thread, if HANDSHAKE is valid
# we will send socket to process pool
threading.Thread(
target=process_connection, args=(client, addr, prcs.best_child())
).start()
except socket.timeout:
pass # Continue and retry
except Exception as e:
logger.error('LOOP: %s', e)
except Exception as e:
sys.stderr.write(f'Error: {e}\n')
logger.error('MAIN: %s', e)
# Check if we have reached the max number of connections
# First part is checked on a thread, if HANDSHAKE is valid
# we will send socket to process pool
executor.submit(process_connection, client, addr, prcs.best_child())
except socket.timeout:
pass # Continue and retry
except Exception as e:
logger.error('LOOP: %s', e)
except Exception as e:
sys.stderr.write(f'Error: {e}\n')
logger.error('MAIN: %s', e)
if sock:
sock.close()