mirror of
https://github.com/dkmstr/openuds.git
synced 2025-01-05 09:17:54 +03:00
Some mre lints
This commit is contained in:
parent
379ce8a094
commit
99cd7030e0
@ -66,7 +66,7 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
|
|||||||
|
|
||||||
async with tools.AsyncTCPServer(
|
async with tools.AsyncTCPServer(
|
||||||
host=host, port=remote_port, callback=callback, name='client_task'
|
host=host, port=remote_port, callback=callback, name='client_task'
|
||||||
) as server:
|
) as server: # pylint: disable=unused-variable
|
||||||
# Create a random ticket with valid format
|
# Create a random ticket with valid format
|
||||||
ticket = tuntools.get_correct_ticket(prefix=f'bX0bwmb{remote_port}bX0bwmb')
|
ticket = tuntools.get_correct_ticket(prefix=f'bX0bwmb{remote_port}bX0bwmb')
|
||||||
# Open and send handshake
|
# Open and send handshake
|
||||||
@ -140,7 +140,7 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
|
|||||||
async with tuntools.create_fake_broker_server(
|
async with tuntools.create_fake_broker_server(
|
||||||
host,
|
host,
|
||||||
fake_broker_port,
|
fake_broker_port,
|
||||||
response=lambda data: conf.UDS_GET_TICKET_RESPONSE(host, extract_port(data)),
|
response=lambda data: conf.UDS_GET_TICKET_RESPONSE(host, extract_port(data)), # pylint: disable=cell-var-from-loop
|
||||||
) as req_queue:
|
) as req_queue:
|
||||||
if req_queue is None:
|
if req_queue is None:
|
||||||
raise AssertionError('req_queue is None')
|
raise AssertionError('req_queue is None')
|
||||||
@ -151,11 +151,11 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
|
|||||||
wait_for_port=True,
|
wait_for_port=True,
|
||||||
# Tunnel config
|
# Tunnel config
|
||||||
uds_server=url,
|
uds_server=url,
|
||||||
logfile='/tmp/tunnel_test.log',
|
logfile='/tmp/tunnel_test.log', # nosec: Testing file, fine to be in /tmp
|
||||||
loglevel='DEBUG',
|
loglevel='DEBUG',
|
||||||
workers=4,
|
workers=4,
|
||||||
command_timeout=16, # Increase command timeout because heavy load we will create
|
command_timeout=16, # Increase command timeout because heavy load we will create
|
||||||
) as process:
|
) as process: # pylint: disable=unused-variable
|
||||||
# Create a "bunch" of clients
|
# Create a "bunch" of clients
|
||||||
tasks = [
|
tasks = [
|
||||||
asyncio.create_task(self.client_task(host, tunnel_server_port, remote_port + i))
|
asyncio.create_task(self.client_task(host, tunnel_server_port, remote_port + i))
|
||||||
@ -188,10 +188,6 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
|
|||||||
return int(data.split(b'bX0bwmb')[1])
|
return int(data.split(b'bX0bwmb')[1])
|
||||||
|
|
||||||
for host in ('127.0.0.1', '::1'):
|
for host in ('127.0.0.1', '::1'):
|
||||||
if ':' in host:
|
|
||||||
url = f'http://[{host}]:{fake_broker_port}/uds/rest'
|
|
||||||
else:
|
|
||||||
url = f'http://{host}:{fake_broker_port}/uds/rest'
|
|
||||||
|
|
||||||
req_queue = asyncio.Queue() # clear queue
|
req_queue = asyncio.Queue() # clear queue
|
||||||
# Use tunnel proc for testing
|
# Use tunnel proc for testing
|
||||||
@ -199,10 +195,10 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
|
|||||||
async with tuntools.create_tunnel_proc(
|
async with tuntools.create_tunnel_proc(
|
||||||
host,
|
host,
|
||||||
tunnel_server_port,
|
tunnel_server_port,
|
||||||
response=lambda data: conf.UDS_GET_TICKET_RESPONSE(host, extract_port(data)),
|
response=lambda data: conf.UDS_GET_TICKET_RESPONSE(host, extract_port(data)), # pylint: disable=cell-var-from-loop
|
||||||
command_timeout=16, # Increase command timeout because heavy load we will create,
|
command_timeout=16, # Increase command timeout because heavy load we will create,
|
||||||
global_stats=stats_collector,
|
global_stats=stats_collector,
|
||||||
) as (cfg, _):
|
) as _: # (_ is a tuple, but not used here, just the context)
|
||||||
# Create a "bunch" of clients
|
# Create a "bunch" of clients
|
||||||
tasks = [
|
tasks = [
|
||||||
asyncio.create_task(self.client_task(host, tunnel_server_port, remote_port + i))
|
asyncio.create_task(self.client_task(host, tunnel_server_port, remote_port + i))
|
||||||
|
@ -88,7 +88,7 @@ def selfSignedCert(ip: str, use_password: bool = True) -> typing.Tuple[str, str,
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def sslContext(ip: str) -> typing.Tuple[ssl.SSLContext, str, str]:
|
def sslContext() -> typing.Tuple[ssl.SSLContext, str, str]: # pylint: disable=unused-argument
|
||||||
"""Returns an ssl context an the certificate & password for an ip
|
"""Returns an ssl context an the certificate & password for an ip
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -96,13 +96,14 @@ def sslContext(ip: str) -> typing.Tuple[ssl.SSLContext, str, str]:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
typing.Tuple[ssl.SSLContext, str, str]: ssl context, certificate file and password
|
typing.Tuple[ssl.SSLContext, str, str]: ssl context, certificate file and password
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# First, create server cert and key on temp dir
|
# First, create server cert and key on temp dir
|
||||||
tmpdir = tempfile.gettempdir()
|
tmpdir = tempfile.gettempdir()
|
||||||
tmpname = secrets.token_urlsafe(32)
|
tmpname = secrets.token_urlsafe(32)
|
||||||
cert, key, password = selfSignedCert('127.0.0.1')
|
cert, key, password = selfSignedCert('127.0.0.1')
|
||||||
cert_file = f'{tmpdir}/{tmpname}.pem'
|
cert_file = f'{tmpdir}/{tmpname}.pem'
|
||||||
with open(cert_file, 'w') as f:
|
with open(cert_file, 'w', encoding='utf-8') as f:
|
||||||
f.write(key)
|
f.write(key)
|
||||||
f.write(cert)
|
f.write(cert)
|
||||||
# Create SSL context
|
# Create SSL context
|
||||||
@ -115,7 +116,7 @@ def sslContext(ip: str) -> typing.Tuple[ssl.SSLContext, str, str]:
|
|||||||
return ssl_ctx, cert_file, password
|
return ssl_ctx, cert_file, password
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def ssl_context(ip: str) -> typing.Generator[typing.Tuple[ssl.SSLContext, str], None, None]:
|
def ssl_context() -> typing.Generator[typing.Tuple[ssl.SSLContext, str], None, None]:
|
||||||
"""Returns an ssl context for an ip
|
"""Returns an ssl context for an ip
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -125,7 +126,7 @@ def ssl_context(ip: str) -> typing.Generator[typing.Tuple[ssl.SSLContext, str],
|
|||||||
ssl.SSLContext: ssl context
|
ssl.SSLContext: ssl context
|
||||||
"""
|
"""
|
||||||
# First, create server cert and key on temp dir
|
# First, create server cert and key on temp dir
|
||||||
ssl_ctx, cert_file, password = sslContext(ip)
|
ssl_ctx, cert_file, password = sslContext() # pylint: disable=unused-variable
|
||||||
|
|
||||||
yield ssl_ctx, cert_file
|
yield ssl_ctx, cert_file
|
||||||
|
|
||||||
|
@ -85,7 +85,7 @@ def create_config_file(
|
|||||||
'ssl_dhparam': '',
|
'ssl_dhparam': '',
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
values, cfg = fixtures.get_config(
|
values, cfg = fixtures.get_config( # pylint: disable=unused-variable
|
||||||
**values,
|
**values,
|
||||||
)
|
)
|
||||||
# Write config file
|
# Write config file
|
||||||
@ -97,21 +97,20 @@ def create_config_file(
|
|||||||
try:
|
try:
|
||||||
yield cfgfile
|
yield cfgfile
|
||||||
finally:
|
finally:
|
||||||
pass
|
|
||||||
# Remove the files if they exists
|
# Remove the files if they exists
|
||||||
for filename in (cfgfile, cert_file):
|
for filename in (cfgfile, cert_file):
|
||||||
try:
|
try:
|
||||||
os.remove(filename)
|
os.remove(filename)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
logger.warning('Error removing %s', filename)
|
||||||
|
|
||||||
|
|
||||||
@contextlib.asynccontextmanager
|
@contextlib.asynccontextmanager
|
||||||
async def create_tunnel_proc(
|
async def create_tunnel_proc(
|
||||||
listen_host: str,
|
listen_host: str,
|
||||||
listen_port: int,
|
listen_port: int,
|
||||||
remote_host: str = '0.0.0.0', # Not used if response is provided
|
remote_host: str = '0.0.0.0', # nosec: intentionally value, Not used if response is provided
|
||||||
remote_port: int = 0, # Not used if response is provided
|
remote_port: int = 0, # Not used if response is provided
|
||||||
*,
|
*,
|
||||||
response: typing.Optional[
|
response: typing.Optional[
|
||||||
typing.Union[
|
typing.Union[
|
||||||
@ -147,7 +146,7 @@ async def create_tunnel_proc(
|
|||||||
if response is None:
|
if response is None:
|
||||||
response = conf.UDS_GET_TICKET_RESPONSE(remote_host, remote_port)
|
response = conf.UDS_GET_TICKET_RESPONSE(remote_host, remote_port)
|
||||||
|
|
||||||
port = random.randint(20000, 30000)
|
port = random.randint(8000, 58000) # nosec Just a random port
|
||||||
hhost = f'[{listen_host}]' if ':' in listen_host else listen_host
|
hhost = f'[{listen_host}]' if ':' in listen_host else listen_host
|
||||||
args = {
|
args = {
|
||||||
'uds_server': f'http://{hhost}:{port}/uds/rest',
|
'uds_server': f'http://{hhost}:{port}/uds/rest',
|
||||||
@ -159,12 +158,8 @@ async def create_tunnel_proc(
|
|||||||
resp = conf.UDS_GET_TICKET_RESPONSE(remote_host, remote_port)
|
resp = conf.UDS_GET_TICKET_RESPONSE(remote_host, remote_port)
|
||||||
|
|
||||||
@contextlib.asynccontextmanager
|
@contextlib.asynccontextmanager
|
||||||
async def provider() -> collections.abc.AsyncGenerator[
|
async def provider() -> collections.abc.AsyncGenerator[typing.Optional[asyncio.Queue[bytes]], None]:
|
||||||
typing.Optional[asyncio.Queue[bytes]], None
|
async with create_fake_broker_server(listen_host, port, response=response or resp) as queue:
|
||||||
]:
|
|
||||||
async with create_fake_broker_server(
|
|
||||||
listen_host, port, response=response or resp
|
|
||||||
) as queue:
|
|
||||||
try:
|
try:
|
||||||
yield queue
|
yield queue
|
||||||
finally:
|
finally:
|
||||||
@ -173,9 +168,7 @@ async def create_tunnel_proc(
|
|||||||
else:
|
else:
|
||||||
|
|
||||||
@contextlib.asynccontextmanager
|
@contextlib.asynccontextmanager
|
||||||
async def provider() -> collections.abc.AsyncGenerator[
|
async def provider() -> collections.abc.AsyncGenerator[typing.Optional[asyncio.Queue[bytes]], None]:
|
||||||
typing.Optional[asyncio.Queue[bytes]], None
|
|
||||||
]:
|
|
||||||
with mock.patch(
|
with mock.patch(
|
||||||
'uds_tunnel.tunnel.TunnelProtocol._read_from_uds',
|
'uds_tunnel.tunnel.TunnelProtocol._read_from_uds',
|
||||||
new_callable=tools.AsyncMock,
|
new_callable=tools.AsyncMock,
|
||||||
@ -210,9 +203,7 @@ async def create_tunnel_proc(
|
|||||||
udstunnel.do_stop.clear()
|
udstunnel.do_stop.clear()
|
||||||
|
|
||||||
# Create the tunnel task
|
# Create the tunnel task
|
||||||
task = asyncio.create_task(
|
task = asyncio.create_task(udstunnel.tunnel_proc_async(other_end, cfg, global_stats.ns))
|
||||||
udstunnel.tunnel_proc_async(other_end, cfg, global_stats.ns)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create a small asyncio server that reads the handshake,
|
# Create a small asyncio server that reads the handshake,
|
||||||
# and sends the socket to the tunnel_proc_async using the pipe
|
# and sends the socket to the tunnel_proc_async using the pipe
|
||||||
@ -262,12 +253,10 @@ async def create_tunnel_proc(
|
|||||||
try:
|
try:
|
||||||
os.unlink(h.baseFilename)
|
os.unlink(h.baseFilename)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
logger.warning('Could not remove log file %s', h.baseFilename)
|
||||||
|
|
||||||
|
|
||||||
async def create_tunnel_server(
|
async def create_tunnel_server(cfg: 'config.ConfigurationType', context: 'ssl.SSLContext') -> 'asyncio.Server':
|
||||||
cfg: 'config.ConfigurationType', context: 'ssl.SSLContext'
|
|
||||||
) -> 'asyncio.Server':
|
|
||||||
# Create fake proxy
|
# Create fake proxy
|
||||||
proxy = mock.MagicMock()
|
proxy = mock.MagicMock()
|
||||||
proxy.cfg = cfg
|
proxy.cfg = cfg
|
||||||
@ -286,9 +275,7 @@ async def create_tunnel_server(
|
|||||||
cfg.listen_address,
|
cfg.listen_address,
|
||||||
cfg.listen_port,
|
cfg.listen_port,
|
||||||
ssl=context,
|
ssl=context,
|
||||||
family=socket.AF_INET6
|
family=socket.AF_INET6 if cfg.ipv6 or ':' in cfg.listen_address else socket.AF_INET,
|
||||||
if cfg.ipv6 or ':' in cfg.listen_address
|
|
||||||
else socket.AF_INET,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -308,7 +295,7 @@ async def create_test_tunnel(
|
|||||||
) as server:
|
) as server:
|
||||||
# Create a tunnel to localhost 13579
|
# Create a tunnel to localhost 13579
|
||||||
# SSl cert for tunnel server
|
# SSl cert for tunnel server
|
||||||
with certs.ssl_context(server.host) as (ssl_ctx, _):
|
with certs.ssl_context() as (ssl_ctx, _):
|
||||||
_, cfg = fixtures.get_config(
|
_, cfg = fixtures.get_config(
|
||||||
address=server.host,
|
address=server.host,
|
||||||
port=port or 7771,
|
port=port or 7771,
|
||||||
@ -365,10 +352,7 @@ async def create_fake_broker_server(
|
|||||||
else:
|
else:
|
||||||
rr = response or {}
|
rr = response or {}
|
||||||
|
|
||||||
resp: bytes = (
|
resp: bytes = b'HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n\r\n' + json.dumps(rr).encode()
|
||||||
b'HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n\r\n'
|
|
||||||
+ json.dumps(rr).encode()
|
|
||||||
)
|
|
||||||
|
|
||||||
data = b'' # reset data for next
|
data = b'' # reset data for next
|
||||||
# send response
|
# send response
|
||||||
@ -379,7 +363,7 @@ async def create_fake_broker_server(
|
|||||||
|
|
||||||
async with tools.AsyncTCPServer(
|
async with tools.AsyncTCPServer(
|
||||||
host=host, port=port, processor=processor, name='create_fake_broker_server'
|
host=host, port=port, processor=processor, name='create_fake_broker_server'
|
||||||
) as server:
|
) as server: # pylint: disable=unused-variable
|
||||||
try:
|
try:
|
||||||
yield requests
|
yield requests
|
||||||
finally:
|
finally:
|
||||||
@ -391,14 +375,10 @@ async def open_tunnel_client(
|
|||||||
cfg: 'config.ConfigurationType',
|
cfg: 'config.ConfigurationType',
|
||||||
use_tunnel_handshake: bool = False,
|
use_tunnel_handshake: bool = False,
|
||||||
local_port: typing.Optional[int] = None,
|
local_port: typing.Optional[int] = None,
|
||||||
) -> collections.abc.AsyncGenerator[
|
) -> collections.abc.AsyncGenerator[typing.Tuple[asyncio.StreamReader, asyncio.StreamWriter], None]:
|
||||||
typing.Tuple[asyncio.StreamReader, asyncio.StreamWriter], None
|
|
||||||
]:
|
|
||||||
"""opens an ssl socket to the tunnel server"""
|
"""opens an ssl socket to the tunnel server"""
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
family = (
|
family = socket.AF_INET6 if cfg.ipv6 or ':' in cfg.listen_address else socket.AF_INET
|
||||||
socket.AF_INET6 if cfg.ipv6 or ':' in cfg.listen_address else socket.AF_INET
|
|
||||||
)
|
|
||||||
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||||
context.check_hostname = False
|
context.check_hostname = False
|
||||||
context.verify_mode = ssl.CERT_NONE
|
context.verify_mode = ssl.CERT_NONE
|
||||||
@ -475,15 +455,12 @@ async def tunnel_app_runner(
|
|||||||
await process.wait()
|
await process.wait()
|
||||||
|
|
||||||
|
|
||||||
def get_correct_ticket(
|
def get_correct_ticket(length: int = consts.TICKET_LENGTH, *, prefix: typing.Optional[str] = None) -> bytes:
|
||||||
length: int = consts.TICKET_LENGTH, *, prefix: typing.Optional[str] = None
|
|
||||||
) -> bytes:
|
|
||||||
"""Returns a ticket with the correct length"""
|
"""Returns a ticket with the correct length"""
|
||||||
prefix = prefix or ''
|
prefix = prefix or ''
|
||||||
return (
|
return (
|
||||||
''.join(
|
''.join(
|
||||||
random.choice(string.ascii_letters + string.digits)
|
random.choice(string.ascii_letters + string.digits) for _ in range(length - len(prefix)) # nosec just for tests
|
||||||
for _ in range(length - len(prefix))
|
|
||||||
).encode()
|
).encode()
|
||||||
+ prefix.encode()
|
+ prefix.encode()
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user