1
0
mirror of https://github.com/dkmstr/openuds.git synced 2025-03-12 04:58:34 +03:00

fixed bad ssl handshake on tests

This commit is contained in:
Adolfo Gómez García 2023-05-21 16:48:32 +02:00
parent c33c1501f5
commit 2da927d82b
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
4 changed files with 44 additions and 10 deletions

View File

@ -97,6 +97,9 @@ class Proxy:
logger.error('ERROR on %s:%s: %s', src_ip, src_port, e)
if tun:
tun.close_connection()
# Also, ensure socket is closed
if source:
del source
logger.debug('Proxy finished')

View File

@ -275,10 +275,12 @@ class TunnelProtocol(asyncio.Protocol):
def close_connection(self):
try:
self.clean_timeout() # If a timeout is set, clean it
if not self.transport.is_closing():
if not self.transport.is_closing(): # Attribute may alreade not be set
self.transport.close()
except Exception: # nosec: best effort
pass # Ignore errors
except AttributeError: # not initialized transport, fine...
pass
except Exception as e: # nosec: best effort
logger.error('ERROR closing connection: %s', e)
def notify_end(self):
if self.notify_ticket:

View File

@ -245,3 +245,25 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
# Read response
data = await creader.read(1024)
self.assertEqual(data, b'', f'Tunnel host: {tunnel_host}, server host: {host}')
async def test_tunnel_invalid_ssl_handshake(self) -> None:
for tunnel_host in ('127.0.0.1', '::1'):
async with tuntools.create_tunnel_proc(
tunnel_host,
7779,
'localhost',
17222, # Any non used port will do the trick
) as (cfg, _):
async with tuntools.open_tunnel_client(cfg, skip_ssl=True) as (
creader,
cwriter
):
cwriter.write(consts.COMMAND_OPEN) # Will fail, not ssl connection, this is invalid in fact
await cwriter.drain()
# Read response, shoub be empty and at_eof
data = await creader.read(1024)
self.assertEqual(data, b'')
self.assertTrue(creader.at_eof())

View File

@ -206,7 +206,9 @@ async def create_tunnel_proc(
task = asyncio.create_task(udstunnel.tunnel_proc_async(other_end, cfg, global_stats.ns))
# Server listening for connections
server_socket = socket.socket(socket.AF_INET6 if ':' in listen_host else socket.AF_INET, socket.SOCK_STREAM)
server_socket = socket.socket(
socket.AF_INET6 if ':' in listen_host else socket.AF_INET, socket.SOCK_STREAM
)
server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) # Allow reuse of address
server_socket.bind((listen_host, listen_port))
server_socket.listen(8)
@ -218,7 +220,8 @@ async def create_tunnel_proc(
while True:
client, addr = await loop.sock_accept(server_socket)
# Send the socket to the tunnel
own_end.send((client, addr))
own_end.send((client.dup(), addr))
client.close()
except asyncio.CancelledError:
pass # We are closing
except Exception:
@ -226,7 +229,6 @@ async def create_tunnel_proc(
# Close the socket
server_socket.close()
# Create the middleware task
server_task = asyncio.create_task(server())
try:
@ -374,6 +376,7 @@ async def open_tunnel_client(
cfg: 'config.ConfigurationType',
use_tunnel_handshake: bool = False,
local_port: typing.Optional[int] = None,
skip_ssl: bool = False, # Onlt valid if use_tunnel_handshake is False
) -> collections.abc.AsyncGenerator[typing.Tuple[asyncio.StreamReader, asyncio.StreamWriter], None]:
"""opens an ssl socket to the tunnel server"""
loop = asyncio.get_running_loop()
@ -383,9 +386,12 @@ async def open_tunnel_client(
context.verify_mode = ssl.CERT_NONE
if not use_tunnel_handshake:
reader, writer = await asyncio.open_connection(
cfg.listen_address, cfg.listen_port, ssl=context, family=family, ssl_handshake_timeout=1
)
if not skip_ssl:
reader, writer = await asyncio.open_connection(
cfg.listen_address, cfg.listen_port, family=family, ssl=context, ssl_handshake_timeout=1
)
else:
reader, writer = await asyncio.open_connection(cfg.listen_address, cfg.listen_port, family=family)
else:
# Open the socket, send handshake and then upgrade to ssl, non blocking
sock = socket.socket(family, socket.SOCK_STREAM)
@ -465,7 +471,8 @@ def get_correct_ticket(length: int = consts.TICKET_LENGTH, *, prefix: typing.Opt
prefix = prefix or ''
return (
''.join(
random.choice(string.ascii_letters + string.digits) for _ in range(length - len(prefix)) # nosec just for tests
random.choice(string.ascii_letters + string.digits)
for _ in range(length - len(prefix)) # nosec just for tests
).encode()
+ prefix.encode()
)