1
0
mirror of https://github.com/dkmstr/openuds.git synced 2025-01-03 01:17:56 +03:00

Some mre lints

This commit is contained in:
Adolfo Gómez García 2023-05-21 05:45:51 +02:00
parent 379ce8a094
commit 99cd7030e0
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
3 changed files with 33 additions and 59 deletions

View File

@ -66,7 +66,7 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
async with tools.AsyncTCPServer(
host=host, port=remote_port, callback=callback, name='client_task'
) as server:
) as server: # pylint: disable=unused-variable
# Create a random ticket with valid format
ticket = tuntools.get_correct_ticket(prefix=f'bX0bwmb{remote_port}bX0bwmb')
# Open and send handshake
@ -140,7 +140,7 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
async with tuntools.create_fake_broker_server(
host,
fake_broker_port,
response=lambda data: conf.UDS_GET_TICKET_RESPONSE(host, extract_port(data)),
response=lambda data: conf.UDS_GET_TICKET_RESPONSE(host, extract_port(data)), # pylint: disable=cell-var-from-loop
) as req_queue:
if req_queue is None:
raise AssertionError('req_queue is None')
@ -151,11 +151,11 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
wait_for_port=True,
# Tunnel config
uds_server=url,
logfile='/tmp/tunnel_test.log',
logfile='/tmp/tunnel_test.log', # nosec: Testing file, fine to be in /tmp
loglevel='DEBUG',
workers=4,
command_timeout=16, # Increase command timeout because heavy load we will create
) as process:
) as process: # pylint: disable=unused-variable
# Create a "bunch" of clients
tasks = [
asyncio.create_task(self.client_task(host, tunnel_server_port, remote_port + i))
@ -188,10 +188,6 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
return int(data.split(b'bX0bwmb')[1])
for host in ('127.0.0.1', '::1'):
if ':' in host:
url = f'http://[{host}]:{fake_broker_port}/uds/rest'
else:
url = f'http://{host}:{fake_broker_port}/uds/rest'
req_queue = asyncio.Queue() # clear queue
# Use tunnel proc for testing
@ -199,10 +195,10 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
async with tuntools.create_tunnel_proc(
host,
tunnel_server_port,
response=lambda data: conf.UDS_GET_TICKET_RESPONSE(host, extract_port(data)),
response=lambda data: conf.UDS_GET_TICKET_RESPONSE(host, extract_port(data)), # pylint: disable=cell-var-from-loop
command_timeout=16, # Increase command timeout because heavy load we will create,
global_stats=stats_collector,
) as (cfg, _):
) as _: # (_ is a tuple, but not used here, just the context)
# Create a "bunch" of clients
tasks = [
asyncio.create_task(self.client_task(host, tunnel_server_port, remote_port + i))

View File

@ -88,7 +88,7 @@ def selfSignedCert(ip: str, use_password: bool = True) -> typing.Tuple[str, str,
)
def sslContext(ip: str) -> typing.Tuple[ssl.SSLContext, str, str]:
def sslContext() -> typing.Tuple[ssl.SSLContext, str, str]: # pylint: disable=unused-argument
"""Returns an ssl context an the certificate & password for an ip
Args:
@ -96,13 +96,14 @@ def sslContext(ip: str) -> typing.Tuple[ssl.SSLContext, str, str]:
Returns:
typing.Tuple[ssl.SSLContext, str, str]: ssl context, certificate file and password
"""
# First, create server cert and key on temp dir
tmpdir = tempfile.gettempdir()
tmpname = secrets.token_urlsafe(32)
cert, key, password = selfSignedCert('127.0.0.1')
cert_file = f'{tmpdir}/{tmpname}.pem'
with open(cert_file, 'w') as f:
with open(cert_file, 'w', encoding='utf-8') as f:
f.write(key)
f.write(cert)
# Create SSL context
@ -115,7 +116,7 @@ def sslContext(ip: str) -> typing.Tuple[ssl.SSLContext, str, str]:
return ssl_ctx, cert_file, password
@contextlib.contextmanager
def ssl_context(ip: str) -> typing.Generator[typing.Tuple[ssl.SSLContext, str], None, None]:
def ssl_context() -> typing.Generator[typing.Tuple[ssl.SSLContext, str], None, None]:
"""Returns an ssl context for an ip
Args:
@ -125,7 +126,7 @@ def ssl_context(ip: str) -> typing.Generator[typing.Tuple[ssl.SSLContext, str],
ssl.SSLContext: ssl context
"""
# First, create server cert and key on temp dir
ssl_ctx, cert_file, password = sslContext(ip)
ssl_ctx, cert_file, password = sslContext() # pylint: disable=unused-variable
yield ssl_ctx, cert_file

View File

@ -85,7 +85,7 @@ def create_config_file(
'ssl_dhparam': '',
}
)
values, cfg = fixtures.get_config(
values, cfg = fixtures.get_config( # pylint: disable=unused-variable
**values,
)
# Write config file
@ -97,20 +97,19 @@ def create_config_file(
try:
yield cfgfile
finally:
pass
# Remove the files if they exists
for filename in (cfgfile, cert_file):
try:
os.remove(filename)
except Exception:
pass
logger.warning('Error removing %s', filename)
@contextlib.asynccontextmanager
async def create_tunnel_proc(
listen_host: str,
listen_port: int,
remote_host: str = '0.0.0.0', # Not used if response is provided
remote_host: str = '0.0.0.0', # nosec: intentionally value, Not used if response is provided
remote_port: int = 0, # Not used if response is provided
*,
response: typing.Optional[
@ -147,7 +146,7 @@ async def create_tunnel_proc(
if response is None:
response = conf.UDS_GET_TICKET_RESPONSE(remote_host, remote_port)
port = random.randint(20000, 30000)
port = random.randint(8000, 58000) # nosec Just a random port
hhost = f'[{listen_host}]' if ':' in listen_host else listen_host
args = {
'uds_server': f'http://{hhost}:{port}/uds/rest',
@ -159,12 +158,8 @@ async def create_tunnel_proc(
resp = conf.UDS_GET_TICKET_RESPONSE(remote_host, remote_port)
@contextlib.asynccontextmanager
async def provider() -> collections.abc.AsyncGenerator[
typing.Optional[asyncio.Queue[bytes]], None
]:
async with create_fake_broker_server(
listen_host, port, response=response or resp
) as queue:
async def provider() -> collections.abc.AsyncGenerator[typing.Optional[asyncio.Queue[bytes]], None]:
async with create_fake_broker_server(listen_host, port, response=response or resp) as queue:
try:
yield queue
finally:
@ -173,9 +168,7 @@ async def create_tunnel_proc(
else:
@contextlib.asynccontextmanager
async def provider() -> collections.abc.AsyncGenerator[
typing.Optional[asyncio.Queue[bytes]], None
]:
async def provider() -> collections.abc.AsyncGenerator[typing.Optional[asyncio.Queue[bytes]], None]:
with mock.patch(
'uds_tunnel.tunnel.TunnelProtocol._read_from_uds',
new_callable=tools.AsyncMock,
@ -210,9 +203,7 @@ async def create_tunnel_proc(
udstunnel.do_stop.clear()
# Create the tunnel task
task = asyncio.create_task(
udstunnel.tunnel_proc_async(other_end, cfg, global_stats.ns)
)
task = asyncio.create_task(udstunnel.tunnel_proc_async(other_end, cfg, global_stats.ns))
# Create a small asyncio server that reads the handshake,
# and sends the socket to the tunnel_proc_async using the pipe
@ -262,12 +253,10 @@ async def create_tunnel_proc(
try:
os.unlink(h.baseFilename)
except Exception:
pass
logger.warning('Could not remove log file %s', h.baseFilename)
async def create_tunnel_server(
cfg: 'config.ConfigurationType', context: 'ssl.SSLContext'
) -> 'asyncio.Server':
async def create_tunnel_server(cfg: 'config.ConfigurationType', context: 'ssl.SSLContext') -> 'asyncio.Server':
# Create fake proxy
proxy = mock.MagicMock()
proxy.cfg = cfg
@ -286,9 +275,7 @@ async def create_tunnel_server(
cfg.listen_address,
cfg.listen_port,
ssl=context,
family=socket.AF_INET6
if cfg.ipv6 or ':' in cfg.listen_address
else socket.AF_INET,
family=socket.AF_INET6 if cfg.ipv6 or ':' in cfg.listen_address else socket.AF_INET,
)
@ -308,7 +295,7 @@ async def create_test_tunnel(
) as server:
# Create a tunnel to localhost 13579
# SSl cert for tunnel server
with certs.ssl_context(server.host) as (ssl_ctx, _):
with certs.ssl_context() as (ssl_ctx, _):
_, cfg = fixtures.get_config(
address=server.host,
port=port or 7771,
@ -365,10 +352,7 @@ async def create_fake_broker_server(
else:
rr = response or {}
resp: bytes = (
b'HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n\r\n'
+ json.dumps(rr).encode()
)
resp: bytes = b'HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n\r\n' + json.dumps(rr).encode()
data = b'' # reset data for next
# send response
@ -379,7 +363,7 @@ async def create_fake_broker_server(
async with tools.AsyncTCPServer(
host=host, port=port, processor=processor, name='create_fake_broker_server'
) as server:
) as server: # pylint: disable=unused-variable
try:
yield requests
finally:
@ -391,14 +375,10 @@ async def open_tunnel_client(
cfg: 'config.ConfigurationType',
use_tunnel_handshake: bool = False,
local_port: typing.Optional[int] = None,
) -> collections.abc.AsyncGenerator[
typing.Tuple[asyncio.StreamReader, asyncio.StreamWriter], None
]:
) -> collections.abc.AsyncGenerator[typing.Tuple[asyncio.StreamReader, asyncio.StreamWriter], None]:
"""opens an ssl socket to the tunnel server"""
loop = asyncio.get_running_loop()
family = (
socket.AF_INET6 if cfg.ipv6 or ':' in cfg.listen_address else socket.AF_INET
)
family = socket.AF_INET6 if cfg.ipv6 or ':' in cfg.listen_address else socket.AF_INET
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
context.check_hostname = False
context.verify_mode = ssl.CERT_NONE
@ -475,15 +455,12 @@ async def tunnel_app_runner(
await process.wait()
def get_correct_ticket(
length: int = consts.TICKET_LENGTH, *, prefix: typing.Optional[str] = None
) -> bytes:
def get_correct_ticket(length: int = consts.TICKET_LENGTH, *, prefix: typing.Optional[str] = None) -> bytes:
"""Returns a ticket with the correct length"""
prefix = prefix or ''
return (
''.join(
random.choice(string.ascii_letters + string.digits)
for _ in range(length - len(prefix))
random.choice(string.ascii_letters + string.digits) for _ in range(length - len(prefix)) # nosec just for tests
).encode()
+ prefix.encode()
)