1
0
mirror of https://github.com/dkmstr/openuds.git synced 2025-01-05 09:17:54 +03:00

Some mre lints

This commit is contained in:
Adolfo Gómez García 2023-05-21 05:45:51 +02:00
parent 379ce8a094
commit 99cd7030e0
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
3 changed files with 33 additions and 59 deletions

View File

@ -66,7 +66,7 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
async with tools.AsyncTCPServer( async with tools.AsyncTCPServer(
host=host, port=remote_port, callback=callback, name='client_task' host=host, port=remote_port, callback=callback, name='client_task'
) as server: ) as server: # pylint: disable=unused-variable
# Create a random ticket with valid format # Create a random ticket with valid format
ticket = tuntools.get_correct_ticket(prefix=f'bX0bwmb{remote_port}bX0bwmb') ticket = tuntools.get_correct_ticket(prefix=f'bX0bwmb{remote_port}bX0bwmb')
# Open and send handshake # Open and send handshake
@ -140,7 +140,7 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
async with tuntools.create_fake_broker_server( async with tuntools.create_fake_broker_server(
host, host,
fake_broker_port, fake_broker_port,
response=lambda data: conf.UDS_GET_TICKET_RESPONSE(host, extract_port(data)), response=lambda data: conf.UDS_GET_TICKET_RESPONSE(host, extract_port(data)), # pylint: disable=cell-var-from-loop
) as req_queue: ) as req_queue:
if req_queue is None: if req_queue is None:
raise AssertionError('req_queue is None') raise AssertionError('req_queue is None')
@ -151,11 +151,11 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
wait_for_port=True, wait_for_port=True,
# Tunnel config # Tunnel config
uds_server=url, uds_server=url,
logfile='/tmp/tunnel_test.log', logfile='/tmp/tunnel_test.log', # nosec: Testing file, fine to be in /tmp
loglevel='DEBUG', loglevel='DEBUG',
workers=4, workers=4,
command_timeout=16, # Increase command timeout because heavy load we will create command_timeout=16, # Increase command timeout because heavy load we will create
) as process: ) as process: # pylint: disable=unused-variable
# Create a "bunch" of clients # Create a "bunch" of clients
tasks = [ tasks = [
asyncio.create_task(self.client_task(host, tunnel_server_port, remote_port + i)) asyncio.create_task(self.client_task(host, tunnel_server_port, remote_port + i))
@ -188,10 +188,6 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
return int(data.split(b'bX0bwmb')[1]) return int(data.split(b'bX0bwmb')[1])
for host in ('127.0.0.1', '::1'): for host in ('127.0.0.1', '::1'):
if ':' in host:
url = f'http://[{host}]:{fake_broker_port}/uds/rest'
else:
url = f'http://{host}:{fake_broker_port}/uds/rest'
req_queue = asyncio.Queue() # clear queue req_queue = asyncio.Queue() # clear queue
# Use tunnel proc for testing # Use tunnel proc for testing
@ -199,10 +195,10 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
async with tuntools.create_tunnel_proc( async with tuntools.create_tunnel_proc(
host, host,
tunnel_server_port, tunnel_server_port,
response=lambda data: conf.UDS_GET_TICKET_RESPONSE(host, extract_port(data)), response=lambda data: conf.UDS_GET_TICKET_RESPONSE(host, extract_port(data)), # pylint: disable=cell-var-from-loop
command_timeout=16, # Increase command timeout because heavy load we will create, command_timeout=16, # Increase command timeout because heavy load we will create,
global_stats=stats_collector, global_stats=stats_collector,
) as (cfg, _): ) as _: # (_ is a tuple, but not used here, just the context)
# Create a "bunch" of clients # Create a "bunch" of clients
tasks = [ tasks = [
asyncio.create_task(self.client_task(host, tunnel_server_port, remote_port + i)) asyncio.create_task(self.client_task(host, tunnel_server_port, remote_port + i))

View File

@ -88,7 +88,7 @@ def selfSignedCert(ip: str, use_password: bool = True) -> typing.Tuple[str, str,
) )
def sslContext(ip: str) -> typing.Tuple[ssl.SSLContext, str, str]: def sslContext() -> typing.Tuple[ssl.SSLContext, str, str]: # pylint: disable=unused-argument
"""Returns an ssl context an the certificate & password for an ip """Returns an ssl context an the certificate & password for an ip
Args: Args:
@ -96,13 +96,14 @@ def sslContext(ip: str) -> typing.Tuple[ssl.SSLContext, str, str]:
Returns: Returns:
typing.Tuple[ssl.SSLContext, str, str]: ssl context, certificate file and password typing.Tuple[ssl.SSLContext, str, str]: ssl context, certificate file and password
""" """
# First, create server cert and key on temp dir # First, create server cert and key on temp dir
tmpdir = tempfile.gettempdir() tmpdir = tempfile.gettempdir()
tmpname = secrets.token_urlsafe(32) tmpname = secrets.token_urlsafe(32)
cert, key, password = selfSignedCert('127.0.0.1') cert, key, password = selfSignedCert('127.0.0.1')
cert_file = f'{tmpdir}/{tmpname}.pem' cert_file = f'{tmpdir}/{tmpname}.pem'
with open(cert_file, 'w') as f: with open(cert_file, 'w', encoding='utf-8') as f:
f.write(key) f.write(key)
f.write(cert) f.write(cert)
# Create SSL context # Create SSL context
@ -115,7 +116,7 @@ def sslContext(ip: str) -> typing.Tuple[ssl.SSLContext, str, str]:
return ssl_ctx, cert_file, password return ssl_ctx, cert_file, password
@contextlib.contextmanager @contextlib.contextmanager
def ssl_context(ip: str) -> typing.Generator[typing.Tuple[ssl.SSLContext, str], None, None]: def ssl_context() -> typing.Generator[typing.Tuple[ssl.SSLContext, str], None, None]:
"""Returns an ssl context for an ip """Returns an ssl context for an ip
Args: Args:
@ -125,7 +126,7 @@ def ssl_context(ip: str) -> typing.Generator[typing.Tuple[ssl.SSLContext, str],
ssl.SSLContext: ssl context ssl.SSLContext: ssl context
""" """
# First, create server cert and key on temp dir # First, create server cert and key on temp dir
ssl_ctx, cert_file, password = sslContext(ip) ssl_ctx, cert_file, password = sslContext() # pylint: disable=unused-variable
yield ssl_ctx, cert_file yield ssl_ctx, cert_file

View File

@ -85,7 +85,7 @@ def create_config_file(
'ssl_dhparam': '', 'ssl_dhparam': '',
} }
) )
values, cfg = fixtures.get_config( values, cfg = fixtures.get_config( # pylint: disable=unused-variable
**values, **values,
) )
# Write config file # Write config file
@ -97,20 +97,19 @@ def create_config_file(
try: try:
yield cfgfile yield cfgfile
finally: finally:
pass
# Remove the files if they exists # Remove the files if they exists
for filename in (cfgfile, cert_file): for filename in (cfgfile, cert_file):
try: try:
os.remove(filename) os.remove(filename)
except Exception: except Exception:
pass logger.warning('Error removing %s', filename)
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def create_tunnel_proc( async def create_tunnel_proc(
listen_host: str, listen_host: str,
listen_port: int, listen_port: int,
remote_host: str = '0.0.0.0', # Not used if response is provided remote_host: str = '0.0.0.0', # nosec: intentionally value, Not used if response is provided
remote_port: int = 0, # Not used if response is provided remote_port: int = 0, # Not used if response is provided
*, *,
response: typing.Optional[ response: typing.Optional[
@ -147,7 +146,7 @@ async def create_tunnel_proc(
if response is None: if response is None:
response = conf.UDS_GET_TICKET_RESPONSE(remote_host, remote_port) response = conf.UDS_GET_TICKET_RESPONSE(remote_host, remote_port)
port = random.randint(20000, 30000) port = random.randint(8000, 58000) # nosec Just a random port
hhost = f'[{listen_host}]' if ':' in listen_host else listen_host hhost = f'[{listen_host}]' if ':' in listen_host else listen_host
args = { args = {
'uds_server': f'http://{hhost}:{port}/uds/rest', 'uds_server': f'http://{hhost}:{port}/uds/rest',
@ -159,12 +158,8 @@ async def create_tunnel_proc(
resp = conf.UDS_GET_TICKET_RESPONSE(remote_host, remote_port) resp = conf.UDS_GET_TICKET_RESPONSE(remote_host, remote_port)
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def provider() -> collections.abc.AsyncGenerator[ async def provider() -> collections.abc.AsyncGenerator[typing.Optional[asyncio.Queue[bytes]], None]:
typing.Optional[asyncio.Queue[bytes]], None async with create_fake_broker_server(listen_host, port, response=response or resp) as queue:
]:
async with create_fake_broker_server(
listen_host, port, response=response or resp
) as queue:
try: try:
yield queue yield queue
finally: finally:
@ -173,9 +168,7 @@ async def create_tunnel_proc(
else: else:
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def provider() -> collections.abc.AsyncGenerator[ async def provider() -> collections.abc.AsyncGenerator[typing.Optional[asyncio.Queue[bytes]], None]:
typing.Optional[asyncio.Queue[bytes]], None
]:
with mock.patch( with mock.patch(
'uds_tunnel.tunnel.TunnelProtocol._read_from_uds', 'uds_tunnel.tunnel.TunnelProtocol._read_from_uds',
new_callable=tools.AsyncMock, new_callable=tools.AsyncMock,
@ -210,9 +203,7 @@ async def create_tunnel_proc(
udstunnel.do_stop.clear() udstunnel.do_stop.clear()
# Create the tunnel task # Create the tunnel task
task = asyncio.create_task( task = asyncio.create_task(udstunnel.tunnel_proc_async(other_end, cfg, global_stats.ns))
udstunnel.tunnel_proc_async(other_end, cfg, global_stats.ns)
)
# Create a small asyncio server that reads the handshake, # Create a small asyncio server that reads the handshake,
# and sends the socket to the tunnel_proc_async using the pipe # and sends the socket to the tunnel_proc_async using the pipe
@ -262,12 +253,10 @@ async def create_tunnel_proc(
try: try:
os.unlink(h.baseFilename) os.unlink(h.baseFilename)
except Exception: except Exception:
pass logger.warning('Could not remove log file %s', h.baseFilename)
async def create_tunnel_server( async def create_tunnel_server(cfg: 'config.ConfigurationType', context: 'ssl.SSLContext') -> 'asyncio.Server':
cfg: 'config.ConfigurationType', context: 'ssl.SSLContext'
) -> 'asyncio.Server':
# Create fake proxy # Create fake proxy
proxy = mock.MagicMock() proxy = mock.MagicMock()
proxy.cfg = cfg proxy.cfg = cfg
@ -286,9 +275,7 @@ async def create_tunnel_server(
cfg.listen_address, cfg.listen_address,
cfg.listen_port, cfg.listen_port,
ssl=context, ssl=context,
family=socket.AF_INET6 family=socket.AF_INET6 if cfg.ipv6 or ':' in cfg.listen_address else socket.AF_INET,
if cfg.ipv6 or ':' in cfg.listen_address
else socket.AF_INET,
) )
@ -308,7 +295,7 @@ async def create_test_tunnel(
) as server: ) as server:
# Create a tunnel to localhost 13579 # Create a tunnel to localhost 13579
# SSl cert for tunnel server # SSl cert for tunnel server
with certs.ssl_context(server.host) as (ssl_ctx, _): with certs.ssl_context() as (ssl_ctx, _):
_, cfg = fixtures.get_config( _, cfg = fixtures.get_config(
address=server.host, address=server.host,
port=port or 7771, port=port or 7771,
@ -365,10 +352,7 @@ async def create_fake_broker_server(
else: else:
rr = response or {} rr = response or {}
resp: bytes = ( resp: bytes = b'HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n\r\n' + json.dumps(rr).encode()
b'HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n\r\n'
+ json.dumps(rr).encode()
)
data = b'' # reset data for next data = b'' # reset data for next
# send response # send response
@ -379,7 +363,7 @@ async def create_fake_broker_server(
async with tools.AsyncTCPServer( async with tools.AsyncTCPServer(
host=host, port=port, processor=processor, name='create_fake_broker_server' host=host, port=port, processor=processor, name='create_fake_broker_server'
) as server: ) as server: # pylint: disable=unused-variable
try: try:
yield requests yield requests
finally: finally:
@ -391,14 +375,10 @@ async def open_tunnel_client(
cfg: 'config.ConfigurationType', cfg: 'config.ConfigurationType',
use_tunnel_handshake: bool = False, use_tunnel_handshake: bool = False,
local_port: typing.Optional[int] = None, local_port: typing.Optional[int] = None,
) -> collections.abc.AsyncGenerator[ ) -> collections.abc.AsyncGenerator[typing.Tuple[asyncio.StreamReader, asyncio.StreamWriter], None]:
typing.Tuple[asyncio.StreamReader, asyncio.StreamWriter], None
]:
"""opens an ssl socket to the tunnel server""" """opens an ssl socket to the tunnel server"""
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
family = ( family = socket.AF_INET6 if cfg.ipv6 or ':' in cfg.listen_address else socket.AF_INET
socket.AF_INET6 if cfg.ipv6 or ':' in cfg.listen_address else socket.AF_INET
)
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
context.check_hostname = False context.check_hostname = False
context.verify_mode = ssl.CERT_NONE context.verify_mode = ssl.CERT_NONE
@ -475,15 +455,12 @@ async def tunnel_app_runner(
await process.wait() await process.wait()
def get_correct_ticket( def get_correct_ticket(length: int = consts.TICKET_LENGTH, *, prefix: typing.Optional[str] = None) -> bytes:
length: int = consts.TICKET_LENGTH, *, prefix: typing.Optional[str] = None
) -> bytes:
"""Returns a ticket with the correct length""" """Returns a ticket with the correct length"""
prefix = prefix or '' prefix = prefix or ''
return ( return (
''.join( ''.join(
random.choice(string.ascii_letters + string.digits) random.choice(string.ascii_letters + string.digits) for _ in range(length - len(prefix)) # nosec just for tests
for _ in range(length - len(prefix))
).encode() ).encode()
+ prefix.encode() + prefix.encode()
) )