mirror of
https://github.com/dkmstr/openuds.git
synced 2025-03-12 04:58:34 +03:00
fixed bad ssl handshake on tests
This commit is contained in:
parent
c33c1501f5
commit
2da927d82b
@ -97,6 +97,9 @@ class Proxy:
|
||||
logger.error('ERROR on %s:%s: %s', src_ip, src_port, e)
|
||||
if tun:
|
||||
tun.close_connection()
|
||||
# Also, ensure socket is closed
|
||||
if source:
|
||||
del source
|
||||
|
||||
logger.debug('Proxy finished')
|
||||
|
||||
|
@ -275,10 +275,12 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
def close_connection(self):
|
||||
try:
|
||||
self.clean_timeout() # If a timeout is set, clean it
|
||||
if not self.transport.is_closing():
|
||||
if not self.transport.is_closing(): # Attribute may alreade not be set
|
||||
self.transport.close()
|
||||
except Exception: # nosec: best effort
|
||||
pass # Ignore errors
|
||||
except AttributeError: # not initialized transport, fine...
|
||||
pass
|
||||
except Exception as e: # nosec: best effort
|
||||
logger.error('ERROR closing connection: %s', e)
|
||||
|
||||
def notify_end(self):
|
||||
if self.notify_ticket:
|
||||
|
@ -245,3 +245,25 @@ class TestUDSTunnelMainProc(IsolatedAsyncioTestCase):
|
||||
# Read response
|
||||
data = await creader.read(1024)
|
||||
self.assertEqual(data, b'', f'Tunnel host: {tunnel_host}, server host: {host}')
|
||||
|
||||
async def test_tunnel_invalid_ssl_handshake(self) -> None:
|
||||
for tunnel_host in ('127.0.0.1', '::1'):
|
||||
async with tuntools.create_tunnel_proc(
|
||||
tunnel_host,
|
||||
7779,
|
||||
'localhost',
|
||||
17222, # Any non used port will do the trick
|
||||
) as (cfg, _):
|
||||
async with tuntools.open_tunnel_client(cfg, skip_ssl=True) as (
|
||||
creader,
|
||||
cwriter
|
||||
):
|
||||
cwriter.write(consts.COMMAND_OPEN) # Will fail, not ssl connection, this is invalid in fact
|
||||
|
||||
await cwriter.drain()
|
||||
|
||||
# Read response, shoub be empty and at_eof
|
||||
data = await creader.read(1024)
|
||||
|
||||
self.assertEqual(data, b'')
|
||||
self.assertTrue(creader.at_eof())
|
||||
|
@ -206,7 +206,9 @@ async def create_tunnel_proc(
|
||||
task = asyncio.create_task(udstunnel.tunnel_proc_async(other_end, cfg, global_stats.ns))
|
||||
|
||||
# Server listening for connections
|
||||
server_socket = socket.socket(socket.AF_INET6 if ':' in listen_host else socket.AF_INET, socket.SOCK_STREAM)
|
||||
server_socket = socket.socket(
|
||||
socket.AF_INET6 if ':' in listen_host else socket.AF_INET, socket.SOCK_STREAM
|
||||
)
|
||||
server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) # Allow reuse of address
|
||||
server_socket.bind((listen_host, listen_port))
|
||||
server_socket.listen(8)
|
||||
@ -218,7 +220,8 @@ async def create_tunnel_proc(
|
||||
while True:
|
||||
client, addr = await loop.sock_accept(server_socket)
|
||||
# Send the socket to the tunnel
|
||||
own_end.send((client, addr))
|
||||
own_end.send((client.dup(), addr))
|
||||
client.close()
|
||||
except asyncio.CancelledError:
|
||||
pass # We are closing
|
||||
except Exception:
|
||||
@ -226,7 +229,6 @@ async def create_tunnel_proc(
|
||||
# Close the socket
|
||||
server_socket.close()
|
||||
|
||||
|
||||
# Create the middleware task
|
||||
server_task = asyncio.create_task(server())
|
||||
try:
|
||||
@ -374,6 +376,7 @@ async def open_tunnel_client(
|
||||
cfg: 'config.ConfigurationType',
|
||||
use_tunnel_handshake: bool = False,
|
||||
local_port: typing.Optional[int] = None,
|
||||
skip_ssl: bool = False, # Onlt valid if use_tunnel_handshake is False
|
||||
) -> collections.abc.AsyncGenerator[typing.Tuple[asyncio.StreamReader, asyncio.StreamWriter], None]:
|
||||
"""opens an ssl socket to the tunnel server"""
|
||||
loop = asyncio.get_running_loop()
|
||||
@ -383,9 +386,12 @@ async def open_tunnel_client(
|
||||
context.verify_mode = ssl.CERT_NONE
|
||||
|
||||
if not use_tunnel_handshake:
|
||||
reader, writer = await asyncio.open_connection(
|
||||
cfg.listen_address, cfg.listen_port, ssl=context, family=family, ssl_handshake_timeout=1
|
||||
)
|
||||
if not skip_ssl:
|
||||
reader, writer = await asyncio.open_connection(
|
||||
cfg.listen_address, cfg.listen_port, family=family, ssl=context, ssl_handshake_timeout=1
|
||||
)
|
||||
else:
|
||||
reader, writer = await asyncio.open_connection(cfg.listen_address, cfg.listen_port, family=family)
|
||||
else:
|
||||
# Open the socket, send handshake and then upgrade to ssl, non blocking
|
||||
sock = socket.socket(family, socket.SOCK_STREAM)
|
||||
@ -465,7 +471,8 @@ def get_correct_ticket(length: int = consts.TICKET_LENGTH, *, prefix: typing.Opt
|
||||
prefix = prefix or ''
|
||||
return (
|
||||
''.join(
|
||||
random.choice(string.ascii_letters + string.digits) for _ in range(length - len(prefix)) # nosec just for tests
|
||||
random.choice(string.ascii_letters + string.digits)
|
||||
for _ in range(length - len(prefix)) # nosec just for tests
|
||||
).encode()
|
||||
+ prefix.encode()
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user