1
0
mirror of https://github.com/dkmstr/openuds.git synced 2025-03-13 08:58:35 +03:00

minor refactor and fixed mock test

This commit is contained in:
Adolfo Gómez García 2022-12-14 23:34:56 +01:00
parent 651cb5802e
commit 33ed68b2d0
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
4 changed files with 15 additions and 8 deletions

View File

@ -59,7 +59,7 @@ class TunnelTicket(Handler):
def get(self) -> typing.MutableMapping[str, typing.Any]:
"""
Processes get requests, currently none
Processes get requests
"""
logger.debug(
'Tunnel parameters for GET: %s (%s) from %s',

View File

@ -45,11 +45,11 @@ logger = logging.getLogger(__name__)
class Proxy:
cfg: 'config.ConfigurationType'
ns: 'Namespace'
args: 'Namespace'
def __init__(self, cfg: 'config.ConfigurationType', ns: 'Namespace') -> None:
def __init__(self, cfg: 'config.ConfigurationType', args: 'Namespace') -> None:
self.cfg = cfg
self.ns = ns
self.args = args
# Method responsible of proxying requests
async def __call__(self, source: socket.socket, context: 'ssl.SSLContext') -> None:

View File

@ -79,7 +79,7 @@ class TunnelProtocol(asyncio.Protocol):
self.runner = self.do_proxy
else:
self.other_side = self
self.stats_manager = stats.Stats(owner.ns)
self.stats_manager = stats.Stats(owner.args)
self.counter = self.stats_manager.as_sent_counter()
self.runner = self.do_command
@ -174,7 +174,7 @@ class TunnelProtocol(asyncio.Protocol):
self.transport.write(b'FORBIDDEN')
return
data = stats.GlobalStats.get_stats(self.owner.ns)
data = stats.GlobalStats.get_stats(self.owner.args)
for v in data:
logger.debug('SENDING %s', v)

View File

@ -28,6 +28,7 @@
'''
Author: Adolfo Gómez, dkmaster at dkmon dot com
'''
import typing
import string
import random
import aiohttp
@ -48,6 +49,11 @@ UDS_GET_TICKET_RESPONSE = {
CALLER_HOST = ('host', 12345)
REMOTE_HOST = ('127.0.0.1', 54876)
def uds_response(_, ticket: bytes, msg: str, queryParams: typing.Optional[typing.Mapping[str, str]] = None) -> typing.Dict[str, typing.Any]:
if msg == 'stop':
return {}
return UDS_GET_TICKET_RESPONSE
class TestTunnel(IsolatedAsyncioTestCase):
async def test_get_ticket_from_uds(self) -> None:
@ -58,7 +64,8 @@ class TestTunnel(IsolatedAsyncioTestCase):
'uds_tunnel.tunnel.TunnelProtocol._readFromUDS',
new_callable=tools.AsyncMock,
) as m:
m.return_value = UDS_GET_TICKET_RESPONSE
m.side_effect = uds_response
#m.return_value = UDS_GET_TICKET_RESPONSE
for i in range(0, 100):
ticket = ''.join(
random.choices(
@ -101,7 +108,7 @@ class TestTunnel(IsolatedAsyncioTestCase):
'uds_tunnel.tunnel.TunnelProtocol._readFromUDS',
new_callable=tools.AsyncMock,
) as m:
m.return_value = {}
m.side_effect = uds_response
counter = mock.MagicMock()
counter.sent = 123456789
counter.recv = 987654321