mirror of
https://github.com/dkmstr/openuds.git
synced 2025-03-11 00:58:39 +03:00
some fixes to make tests work as they are expected to
This commit is contained in:
parent
99cd7030e0
commit
c33c1501f5
@ -68,11 +68,12 @@ ignored-modules=
|
||||
# Python code to execute, usually for sys.path manipulation such as
|
||||
# pygtk.require().
|
||||
#init-hook=
|
||||
init-hook='import sys, os; sys.path.append(os.path.join(os.getcwd(), "src"))'
|
||||
|
||||
# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
|
||||
# number of processors available to use, and will cap the count on Windows to
|
||||
# avoid hangs.
|
||||
jobs=1
|
||||
jobs=4
|
||||
|
||||
# Control the amount of potential inferred values when inferring a single
|
||||
# object. This can help the performance when dealing with large functions or
|
||||
|
@ -73,7 +73,7 @@ init-hook='import sys, os; sys.path.append(os.path.join(os.getcwd(), "src"))'
|
||||
# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
|
||||
# number of processors available to use, and will cap the count on Windows to
|
||||
# avoid hangs.
|
||||
jobs=1
|
||||
jobs=4
|
||||
|
||||
# Control the amount of potential inferred values when inferring a single
|
||||
# object. This can help the performance when dealing with large functions or
|
||||
|
@ -74,11 +74,17 @@ class Proxy:
|
||||
# Handshake correct in this point, upgrade the connection to TSL and let
|
||||
# the protocol controller do the rest
|
||||
|
||||
# Store source ip and port, for logging purposes in case of error
|
||||
src_ip, src_port = (source.getpeername() if source else ('Unknown', 0))[:2] # May be ipv4 or ipv6, so we get only first two elements
|
||||
|
||||
# Upgrade connection to SSL, and use asyncio to handle the rest
|
||||
tun: typing.Optional[tunnel.TunnelProtocol] = None
|
||||
try:
|
||||
tun = tunnel.TunnelProtocol(self)
|
||||
# (connect accepted loop not present on AbastractEventLoop definition < 3.10), that's why we use ignore
|
||||
await loop.connect_accepted_socket( # type: ignore
|
||||
lambda: tunnel.TunnelProtocol(self), source, ssl=context
|
||||
lambda: tun, source, ssl=context,
|
||||
ssl_handshake_timeout=3,
|
||||
)
|
||||
|
||||
# Wait for connection to be closed
|
||||
@ -86,6 +92,11 @@ class Proxy:
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass # Return on cancel
|
||||
except Exception as e:
|
||||
# Any other exception, ensure we close the connection
|
||||
logger.error('ERROR on %s:%s: %s', src_ip, src_port, e)
|
||||
if tun:
|
||||
tun.close_connection()
|
||||
|
||||
logger.debug('Proxy finished')
|
||||
|
||||
|
@ -97,8 +97,6 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
# We start processing command
|
||||
# After command, we can process stats or do_proxy, that is the "normal" operation
|
||||
self.runner = self.do_command
|
||||
# Set starting timeout task, se we dont get hunged on connections without data (or insufficient data)
|
||||
self.set_timeout(self.owner.cfg.command_timeout)
|
||||
|
||||
def process_open(self) -> None:
|
||||
# Open Command has the ticket behind it
|
||||
@ -158,7 +156,7 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
self.transport.write(b'OK')
|
||||
self.stats_manager.increment_connections() # Increment connections counters
|
||||
except Exception as e:
|
||||
logger.error('Error opening connection: %s', e)
|
||||
logger.error('CONNECTION FAILED: %s', e)
|
||||
self.close_connection()
|
||||
|
||||
# add open other side to the loop
|
||||
@ -276,9 +274,10 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
|
||||
def close_connection(self):
|
||||
try:
|
||||
self.clean_timeout() # If a timeout is set, clean it
|
||||
if not self.transport.is_closing():
|
||||
self.transport.close()
|
||||
except Exception:
|
||||
except Exception: # nosec: best effort
|
||||
pass # Ignore errors
|
||||
|
||||
def notify_end(self):
|
||||
@ -304,6 +303,10 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
|
||||
def connection_made(self, transport: 'asyncio.transports.BaseTransport') -> None:
|
||||
# We know for sure that the transport is a Transport.
|
||||
|
||||
# Set starting timeout task, se we dont get hunged on connections without data (or insufficient data)
|
||||
self.set_timeout(self.owner.cfg.command_timeout)
|
||||
|
||||
self.transport = typing.cast('asyncio.transports.Transport', transport)
|
||||
# Get source
|
||||
self.source = self.transport.get_extra_info('peername')
|
||||
|
@ -45,13 +45,14 @@ from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
|
||||
try:
|
||||
import uvloop
|
||||
import uvloop # type: ignore
|
||||
|
||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||
except ImportError:
|
||||
pass # no uvloop support
|
||||
|
||||
try:
|
||||
import setproctitle
|
||||
import setproctitle # type: ignore
|
||||
except ImportError:
|
||||
setproctitle = None # type: ignore
|
||||
|
||||
@ -111,7 +112,12 @@ async def tunnel_proc_async(pipe: 'Connection', cfg: config.ConfigurationType, n
|
||||
|
||||
def add_autoremovable_task(task: asyncio.Task) -> None:
|
||||
tasks.append(task)
|
||||
task.add_done_callback(tasks.remove)
|
||||
|
||||
def remove_task(task: asyncio.Task) -> None:
|
||||
logger.debug('Removing task %s', task)
|
||||
tasks.remove(task)
|
||||
|
||||
task.add_done_callback(remove_task)
|
||||
|
||||
def get_socket() -> typing.Tuple[typing.Optional[socket.socket], typing.Optional[typing.Tuple[str, int]]]:
|
||||
try:
|
||||
@ -174,7 +180,9 @@ async def tunnel_proc_async(pipe: 'Connection', cfg: config.ConfigurationType, n
|
||||
break # No more sockets, exit
|
||||
logger.debug('CONNECTION from %s (pid: %s)', address, os.getpid())
|
||||
# Due to proxy contains an "event" to stop, we need to create a new one for each connection
|
||||
add_autoremovable_task(asyncio.create_task(proxy.Proxy(cfg, ns)(sock, context)))
|
||||
add_autoremovable_task(
|
||||
asyncio.create_task(proxy.Proxy(cfg, ns)(sock, context), name=f'proxy-{address}')
|
||||
)
|
||||
except asyncio.CancelledError: # pylint: disable=try-except-raise
|
||||
raise # Stop, but avoid generic exception
|
||||
except Exception:
|
||||
@ -184,7 +192,7 @@ async def tunnel_proc_async(pipe: 'Connection', cfg: config.ConfigurationType, n
|
||||
|
||||
# create task for server
|
||||
|
||||
add_autoremovable_task(asyncio.create_task(run_server()))
|
||||
add_autoremovable_task(asyncio.create_task(run_server(), name='server'))
|
||||
|
||||
try:
|
||||
while tasks and not do_stop.is_set():
|
||||
@ -195,6 +203,9 @@ async def tunnel_proc_async(pipe: 'Connection', cfg: config.ConfigurationType, n
|
||||
except asyncio.CancelledError:
|
||||
logger.info('Task cancelled')
|
||||
do_stop.set() # ensure we stop
|
||||
except Exception:
|
||||
logger.exception('Error in main loop')
|
||||
do_stop.set()
|
||||
|
||||
logger.debug('Out of loop, stopping tasks: %s, running: %s', tasks, do_stop.is_set())
|
||||
|
||||
@ -236,8 +247,10 @@ def tunnel_main(args: 'argparse.Namespace') -> None:
|
||||
# Try to bind to port as running user
|
||||
# Wait for socket incoming connections and spread them
|
||||
socket.setdefaulttimeout(3.0) # So we can check for stop from time to time and not block forever
|
||||
af_inet = socket.AF_INET6 if args.ipv6 or cfg.ipv6 or ':' in cfg.listen_address else socket.AF_INET
|
||||
sock = socket.socket(af_inet, socket.SOCK_STREAM)
|
||||
sock = socket.socket(
|
||||
socket.AF_INET6 if args.ipv6 or cfg.ipv6 or ':' in cfg.listen_address else socket.AF_INET,
|
||||
socket.SOCK_STREAM,
|
||||
)
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)
|
||||
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
|
||||
# We will not reuse port, we only want a UDS tunnel server running on a port
|
||||
|
@ -45,7 +45,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TestUDSTunnelApp(IsolatedAsyncioTestCase):
|
||||
async def client_task(self, host: str, tunnel_port: int, remote_port: int) -> None:
|
||||
async def client_task(self, host: str, tunnel_port: int, remote_port: int, use_tunnel_handshake: bool = False) -> None:
|
||||
received: bytes = b''
|
||||
callback_invoked: asyncio.Event = asyncio.Event()
|
||||
# Data sent will be received by server
|
||||
@ -77,7 +77,7 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
|
||||
cfg.listen_port = tunnel_port
|
||||
|
||||
async with tuntools.open_tunnel_client(
|
||||
cfg, local_port=remote_port + 10000, use_tunnel_handshake=True
|
||||
cfg, local_port=remote_port + 10000, use_tunnel_handshake=use_tunnel_handshake
|
||||
) as (
|
||||
creader,
|
||||
cwriter,
|
||||
@ -120,7 +120,7 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
|
||||
self.assertEqual(received, test_str)
|
||||
|
||||
async def test_app_concurrency(self) -> None:
|
||||
concurrent_tasks = 512
|
||||
concurrent_tasks = 1024
|
||||
fake_broker_port = 20000
|
||||
tunnel_server_port = fake_broker_port + 1
|
||||
remote_port = fake_broker_port + 2
|
||||
@ -158,7 +158,7 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
|
||||
) as process: # pylint: disable=unused-variable
|
||||
# Create a "bunch" of clients
|
||||
tasks = [
|
||||
asyncio.create_task(self.client_task(host, tunnel_server_port, remote_port + i))
|
||||
asyncio.create_task(self.client_task(host, tunnel_server_port, remote_port + i, use_tunnel_handshake=True))
|
||||
async for i in tools.waitable_range(concurrent_tasks)
|
||||
]
|
||||
|
||||
|
@ -62,7 +62,7 @@ class TestTunnel(IsolatedAsyncioTestCase):
|
||||
# Commands are 4 bytes length, try with less and more invalid commands
|
||||
for i in range(0, 100, 10):
|
||||
# Set timeout to 1 seconds
|
||||
bad_cmd = bytes(random.randint(0, 255) for _ in range(i)) # Some garbage
|
||||
bad_cmd = bytes(random.randint(0, 255) for _ in range(i)) # nosec: Some garbage, not security related
|
||||
logger.info('Testing invalid command with %s', bad_cmd)
|
||||
async with tuntools.create_test_tunnel(callback=lambda x: None, port=7770, remote_port=54555, command_timeout=0.1) as cfg:
|
||||
logger_mock = mock.MagicMock()
|
||||
@ -95,18 +95,18 @@ class TestTunnel(IsolatedAsyncioTestCase):
|
||||
def test_tunnel_invalid_handshake(self) -> None:
|
||||
# Not async test, executed on main thread without event loop
|
||||
# Pipe for testing
|
||||
own_conn, other_conn = multiprocessing.Pipe()
|
||||
own_conn, other_conn = multiprocessing.Pipe() # pylint: disable=unused-variable
|
||||
|
||||
# Some random data to send on each test, all invalid
|
||||
# 0 bytes will make timeout to be reached
|
||||
for i in [i for i in range(10)] + [i for i in range(100, 10000, 100)]:
|
||||
for i in list(range(10)) + list(range(100, 10000, 100)):
|
||||
# Create a simple socket for testing
|
||||
rsock, wsock = socket.socketpair()
|
||||
# Set read timeout to 1 seconds
|
||||
rsock.settimeout(3)
|
||||
|
||||
# Set timeout to 1 seconds
|
||||
bad_handshake = bytes(random.randint(0, 255) for _ in range(i))
|
||||
bad_handshake = bytes(random.randint(0, 255) for _ in range(i)) # nosec not for security
|
||||
logger_mock = mock.MagicMock()
|
||||
with mock.patch('udstunnel.logger', logger_mock):
|
||||
wsock.sendall(bad_handshake)
|
||||
@ -138,4 +138,3 @@ class TestTunnel(IsolatedAsyncioTestCase):
|
||||
# and that other_conn has received a ('host', 'port') tuple
|
||||
# recv()[0] will be a copy of the socket, we don't care about it
|
||||
self.assertEqual(other_conn.recv()[1], ('host', 'port'))
|
||||
|
||||
|
@ -49,7 +49,6 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
|
||||
async def test_run_app_help(self) -> None:
|
||||
# Executes the app with --help
|
||||
async with tuntools.tunnel_app_runner(args=['--help']) as process:
|
||||
|
||||
stdout, stderr = await process.communicate()
|
||||
self.assertEqual(process.returncode, 0, f'{stdout!r} {stderr!r}')
|
||||
self.assertEqual(stderr, b'')
|
||||
@ -57,7 +56,7 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
|
||||
|
||||
async def test_tunnel_fail_cmd(self) -> None:
|
||||
# Test on ipv4 and ipv6
|
||||
for host in ('127.0.0.1', '::1'):
|
||||
for host in ('::1', '127.0.0.1'):
|
||||
# Remote is not really important in this tests, will fail before using it
|
||||
async with tuntools.create_tunnel_proc(
|
||||
host,
|
||||
@ -65,18 +64,17 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
|
||||
'127.0.0.1',
|
||||
13579, # A port not used by any other test
|
||||
command_timeout=0.1,
|
||||
) as (cfg, queue):
|
||||
) as (cfg, queue): # pylint: disable=unused-variable
|
||||
for i in range(0, 8192, 128):
|
||||
# Set timeout to 1 seconds
|
||||
bad_cmd = bytes(
|
||||
random.randint(0, 255) for _ in range(i)
|
||||
) # Some garbage
|
||||
logger.info(f'Testing invalid command with {bad_cmd!r}')
|
||||
bad_cmd = bytes(random.randint(0, 255) for _ in range(i)) # nosec: Some garbage
|
||||
logger.info('Testing invalid command with %s', bad_cmd)
|
||||
# On full, we need the handshake to be done, before connecting
|
||||
# Our "test" server will simple "eat" the handshake, but we need to do it
|
||||
async with tuntools.open_tunnel_client(
|
||||
cfg, use_tunnel_handshake=True
|
||||
) as (creader, cwriter):
|
||||
async with tuntools.open_tunnel_client(cfg) as (
|
||||
creader,
|
||||
cwriter,
|
||||
):
|
||||
cwriter.write(bad_cmd)
|
||||
await cwriter.drain()
|
||||
# Read response
|
||||
@ -95,13 +93,14 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
|
||||
7891,
|
||||
'127.0.0.1',
|
||||
13581,
|
||||
) as (cfg, queue):
|
||||
for i in range(10): # Several times
|
||||
) as (cfg, queue): # pylint: disable=unused-variable
|
||||
for _ in range(10): # Several times
|
||||
# On full, we need the handshake to be done, before connecting
|
||||
# Our "test" server will simple "eat" the handshake, but we need to do it
|
||||
async with tuntools.open_tunnel_client(
|
||||
cfg, use_tunnel_handshake=True
|
||||
) as (creader, cwriter):
|
||||
async with tuntools.open_tunnel_client(cfg) as (
|
||||
creader,
|
||||
cwriter,
|
||||
):
|
||||
cwriter.write(consts.COMMAND_TEST)
|
||||
await cwriter.drain()
|
||||
# Read response
|
||||
@ -119,7 +118,7 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
|
||||
server.host,
|
||||
server.port,
|
||||
command_timeout=0.1,
|
||||
) as (cfg, queue):
|
||||
) as (cfg, queue): # pylint: disable=unused-variable
|
||||
for i in range(
|
||||
0, consts.TICKET_LENGTH - 1, 4
|
||||
): # All will fail. Any longer will be processed, and mock will return correct don't matter the ticket
|
||||
@ -127,9 +126,10 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
|
||||
ticket = tuntools.get_correct_ticket(i)
|
||||
# On full, we need the handshake to be done, before connecting
|
||||
# Our "test" server will simple "eat" the handshake, but we need to do it
|
||||
async with tuntools.open_tunnel_client(
|
||||
cfg, use_tunnel_handshake=True
|
||||
) as (creader, cwriter):
|
||||
async with tuntools.open_tunnel_client(cfg) as (
|
||||
creader,
|
||||
cwriter,
|
||||
):
|
||||
cwriter.write(consts.COMMAND_OPEN)
|
||||
# fake ticket, consts.TICKET_LENGTH bytes long, letters and numbers. Use a random ticket,
|
||||
cwriter.write(ticket)
|
||||
@ -149,13 +149,11 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
|
||||
received += data
|
||||
# if data contains EOS marcker ('STREAM_END'), we are done
|
||||
if b'STREAM_END' in data:
|
||||
callback_invoked.set()
|
||||
callback_invoked.set() # pylint: disable=cell-var-from-loop
|
||||
|
||||
# Remote is important in this tests
|
||||
# create a remote server, use a different port than the tunnel fail test, because tests may run in parallel
|
||||
async with tools.AsyncTCPServer(
|
||||
host=host, port=5445, callback=callback
|
||||
) as server:
|
||||
async with tools.AsyncTCPServer(host=host, port=5445, callback=callback) as server:
|
||||
for tunnel_host in ('127.0.0.1', '::1'):
|
||||
async with tuntools.create_tunnel_proc(
|
||||
tunnel_host,
|
||||
@ -165,17 +163,19 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
|
||||
use_fake_http_server=True,
|
||||
) as (cfg, queue):
|
||||
# Ensure queue is not none but an asyncio.Queue
|
||||
# Note, this also let's mypy know that queue is not None after this point
|
||||
if queue is None:
|
||||
raise AssertionError('Queue is None')
|
||||
|
||||
for i in range(16):
|
||||
|
||||
for _ in range(16):
|
||||
# Create a random ticket with valid format
|
||||
ticket = tuntools.get_correct_ticket()
|
||||
# On full, we need the handshake to be done, before connecting
|
||||
# Our "test" server will simple "eat" the handshake, but we need to do it
|
||||
async with tuntools.open_tunnel_client(
|
||||
cfg, use_tunnel_handshake=True
|
||||
) as (creader, cwriter):
|
||||
async with tuntools.open_tunnel_client(cfg) as (
|
||||
creader,
|
||||
cwriter,
|
||||
):
|
||||
cwriter.write(consts.COMMAND_OPEN)
|
||||
# fake ticket, consts.TICKET_LENGTH bytes long, letters and numbers. Use a random ticket,
|
||||
cwriter.write(ticket)
|
||||
@ -199,14 +199,14 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
|
||||
else:
|
||||
should_be_url = f'/stop/{cfg.uds_token}'.encode()
|
||||
self.assertIn(should_be_url, queue_item)
|
||||
|
||||
|
||||
# Ensure user agent is correct
|
||||
|
||||
# Data sent will be received by server
|
||||
# One single write will ensure all data is on same packet
|
||||
test_str = (
|
||||
b'Some Random Data'
|
||||
+ bytes(random.randint(0, 255) for _ in range(8192))
|
||||
+ bytes(random.randint(0, 255) for _ in range(8192)) # nosec: some random data, not used for security
|
||||
+ b'STREAM_END'
|
||||
)
|
||||
# Clean received data
|
||||
@ -233,9 +233,10 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
|
||||
ticket = tuntools.get_correct_ticket()
|
||||
# On full, we need the handshake to be done, before connecting
|
||||
# Our "test" server will simple "eat" the handshake, but we need to do it
|
||||
async with tuntools.open_tunnel_client(
|
||||
cfg, use_tunnel_handshake=True
|
||||
) as (creader, cwriter):
|
||||
async with tuntools.open_tunnel_client(cfg) as (
|
||||
creader,
|
||||
cwriter,
|
||||
):
|
||||
cwriter.write(consts.COMMAND_OPEN)
|
||||
# fake ticket, consts.TICKET_LENGTH bytes long, letters and numbers. Use a random ticket,
|
||||
cwriter.write(ticket)
|
||||
|
@ -32,18 +32,21 @@ import asyncio
|
||||
import os
|
||||
import ssl
|
||||
import typing
|
||||
import collections.abc
|
||||
import socket
|
||||
import aiohttp
|
||||
import logging
|
||||
from unittest import mock
|
||||
|
||||
import aiohttp
|
||||
|
||||
from . import certs
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
import collections.abc
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class AsyncMock(mock.MagicMock):
|
||||
async def __call__(self, *args, **kwargs):
|
||||
async def __call__(self, *args, **kwargs): # pylint: disable=invalid-overridden-method
|
||||
return super().__call__(*args, **kwargs)
|
||||
|
||||
|
||||
@ -69,7 +72,7 @@ class AsyncHttpServer:
|
||||
self._server = None
|
||||
self._response = response
|
||||
if use_ssl:
|
||||
self._ssl_ctx, self._ssl_cert_file, pwd = certs.sslContext(host)
|
||||
self._ssl_ctx, self._ssl_cert_file, pwd = certs.sslContext() # pylint: disable=unused-variable
|
||||
else:
|
||||
self._ssl_ctx = None
|
||||
self._ssl_cert_file = None
|
||||
@ -150,15 +153,18 @@ class AsyncTCPServer:
|
||||
await self._processor(reader, writer)
|
||||
return
|
||||
while True:
|
||||
data = await reader.read(4096) # Care with this and tunnel handshake on testings...
|
||||
if not data:
|
||||
break
|
||||
try:
|
||||
data = await reader.read(4096) # Care with this and tunnel handshake on testings...
|
||||
if not data:
|
||||
break
|
||||
|
||||
resp = self._callback(data) if self._callback else self._response
|
||||
resp = self._callback(data) if self._callback else self._response
|
||||
|
||||
if resp is not None:
|
||||
writer.write(resp)
|
||||
await writer.drain()
|
||||
if resp is not None:
|
||||
writer.write(resp)
|
||||
await writer.drain()
|
||||
except Exception as e:
|
||||
logger.exception('Exception %s on %s', e, self._name)
|
||||
|
||||
async def __aenter__(self) -> 'AsyncTCPServer':
|
||||
if ':' in self.host:
|
||||
@ -196,7 +202,7 @@ async def wait_for_port(host: str, port: int) -> None:
|
||||
except ConnectionRefusedError:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
async def waitable_range(len: int, wait: float = 0.0001) -> 'collections.abc.AsyncGenerator[int, None]':
|
||||
for i in range(len):
|
||||
async def waitable_range(size: int, wait: float = 0.0001) -> 'collections.abc.AsyncGenerator[int, None]':
|
||||
for i in range(size):
|
||||
await asyncio.sleep(wait)
|
||||
yield i
|
||||
|
@ -146,7 +146,7 @@ async def create_tunnel_proc(
|
||||
if response is None:
|
||||
response = conf.UDS_GET_TICKET_RESPONSE(remote_host, remote_port)
|
||||
|
||||
port = random.randint(8000, 58000) # nosec Just a random port
|
||||
port = random.randint(20000, 40000) # nosec Just a random port
|
||||
hhost = f'[{listen_host}]' if ':' in listen_host else listen_host
|
||||
args = {
|
||||
'uds_server': f'http://{hhost}:{port}/uds/rest',
|
||||
@ -205,34 +205,37 @@ async def create_tunnel_proc(
|
||||
# Create the tunnel task
|
||||
task = asyncio.create_task(udstunnel.tunnel_proc_async(other_end, cfg, global_stats.ns))
|
||||
|
||||
# Create a small asyncio server that reads the handshake,
|
||||
# and sends the socket to the tunnel_proc_async using the pipe
|
||||
# the pipe message will be typing.Tuple[socket.socket, typing.Tuple[str, int]]
|
||||
# socket and address
|
||||
async def client_connected_cb(reader, writer):
|
||||
# Read the handshake
|
||||
# Note: We need a small wait on sender, because this is a bufferedReader
|
||||
# so it will read the handshake and the first bytes of the data (that is the ssl handshake)
|
||||
_ = await reader.read(len(consts.HANDSHAKE_V1))
|
||||
# For testing, we ignore the handshake value
|
||||
# Send the socket to the tunnel
|
||||
own_end.send(
|
||||
(
|
||||
writer.get_extra_info('socket').dup(),
|
||||
writer.get_extra_info('peername'),
|
||||
)
|
||||
)
|
||||
# Close the socket
|
||||
writer.close()
|
||||
# Server listening for connections
|
||||
server_socket = socket.socket(socket.AF_INET6 if ':' in listen_host else socket.AF_INET, socket.SOCK_STREAM)
|
||||
server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) # Allow reuse of address
|
||||
server_socket.bind((listen_host, listen_port))
|
||||
server_socket.listen(8)
|
||||
server_socket.setblocking(False)
|
||||
|
||||
server = await asyncio.start_server(
|
||||
client_connected_cb,
|
||||
listen_host,
|
||||
listen_port,
|
||||
)
|
||||
async def server():
|
||||
loop = asyncio.get_running_loop()
|
||||
try:
|
||||
while True:
|
||||
client, addr = await loop.sock_accept(server_socket)
|
||||
# Send the socket to the tunnel
|
||||
own_end.send((client, addr))
|
||||
except asyncio.CancelledError:
|
||||
pass # We are closing
|
||||
except Exception:
|
||||
logger.exception('Exception in server')
|
||||
# Close the socket
|
||||
server_socket.close()
|
||||
|
||||
|
||||
# Create the middleware task
|
||||
server_task = asyncio.create_task(server())
|
||||
try:
|
||||
yield cfg, possible_queue
|
||||
finally:
|
||||
# Cancel the middleware task
|
||||
server_task.cancel()
|
||||
logger.info('Server closed')
|
||||
|
||||
# Close the pipe (both ends)
|
||||
own_end.close()
|
||||
|
||||
@ -240,10 +243,6 @@ async def create_tunnel_proc(
|
||||
# wait for the task to finish
|
||||
await task
|
||||
|
||||
server.close()
|
||||
await server.wait_closed()
|
||||
logger.info('Server closed')
|
||||
|
||||
# Ensure log file are removed
|
||||
rootlog = logging.getLogger()
|
||||
for h in rootlog.handlers:
|
||||
@ -385,7 +384,7 @@ async def open_tunnel_client(
|
||||
|
||||
if not use_tunnel_handshake:
|
||||
reader, writer = await asyncio.open_connection(
|
||||
cfg.listen_address, cfg.listen_port, ssl=context, family=family
|
||||
cfg.listen_address, cfg.listen_port, ssl=context, family=family, ssl_handshake_timeout=1
|
||||
)
|
||||
else:
|
||||
# Open the socket, send handshake and then upgrade to ssl, non blocking
|
||||
@ -402,7 +401,6 @@ async def open_tunnel_client(
|
||||
# (reads chunks of 4096 bytes). If we don't wait, the handshake will be readed
|
||||
# and part or all of ssl handshake also.
|
||||
# With uvloop this seems to be not needed, but with asyncio it is.
|
||||
await asyncio.sleep(0.05)
|
||||
# upgrade to ssl
|
||||
reader, writer = await asyncio.open_connection(
|
||||
sock=sock, ssl=context, server_hostname=cfg.listen_address
|
||||
@ -451,8 +449,15 @@ async def tunnel_app_runner(
|
||||
finally:
|
||||
# Ensure the process is terminated
|
||||
if process.returncode is None:
|
||||
logger.info('Terminating tunnel process %s', process.pid)
|
||||
process.terminate()
|
||||
await process.wait()
|
||||
await asyncio.wait_for(process.wait(), 10)
|
||||
# Ensure the process is terminated
|
||||
if process.returncode is None:
|
||||
logger.info('Killing tunnel process %s', process.pid)
|
||||
process.kill()
|
||||
await asyncio.wait_for(process.wait(), 10)
|
||||
logger.info('Tunnel process %s terminated', process.pid)
|
||||
|
||||
|
||||
def get_correct_ticket(length: int = consts.TICKET_LENGTH, *, prefix: typing.Optional[str] = None) -> bytes:
|
||||
|
Loading…
x
Reference in New Issue
Block a user