mirror of
https://github.com/dkmstr/openuds.git
synced 2025-03-13 08:58:35 +03:00
minor refactor and fixed mock test
This commit is contained in:
parent
651cb5802e
commit
33ed68b2d0
@ -59,7 +59,7 @@ class TunnelTicket(Handler):
|
||||
|
||||
def get(self) -> typing.MutableMapping[str, typing.Any]:
|
||||
"""
|
||||
Processes get requests, currently none
|
||||
Processes get requests
|
||||
"""
|
||||
logger.debug(
|
||||
'Tunnel parameters for GET: %s (%s) from %s',
|
||||
|
@ -45,11 +45,11 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class Proxy:
|
||||
cfg: 'config.ConfigurationType'
|
||||
ns: 'Namespace'
|
||||
args: 'Namespace'
|
||||
|
||||
def __init__(self, cfg: 'config.ConfigurationType', ns: 'Namespace') -> None:
|
||||
def __init__(self, cfg: 'config.ConfigurationType', args: 'Namespace') -> None:
|
||||
self.cfg = cfg
|
||||
self.ns = ns
|
||||
self.args = args
|
||||
|
||||
# Method responsible of proxying requests
|
||||
async def __call__(self, source: socket.socket, context: 'ssl.SSLContext') -> None:
|
||||
|
@ -79,7 +79,7 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
self.runner = self.do_proxy
|
||||
else:
|
||||
self.other_side = self
|
||||
self.stats_manager = stats.Stats(owner.ns)
|
||||
self.stats_manager = stats.Stats(owner.args)
|
||||
self.counter = self.stats_manager.as_sent_counter()
|
||||
self.runner = self.do_command
|
||||
|
||||
@ -174,7 +174,7 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
self.transport.write(b'FORBIDDEN')
|
||||
return
|
||||
|
||||
data = stats.GlobalStats.get_stats(self.owner.ns)
|
||||
data = stats.GlobalStats.get_stats(self.owner.args)
|
||||
|
||||
for v in data:
|
||||
logger.debug('SENDING %s', v)
|
||||
|
@ -28,6 +28,7 @@
|
||||
'''
|
||||
Author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
'''
|
||||
import typing
|
||||
import string
|
||||
import random
|
||||
import aiohttp
|
||||
@ -48,6 +49,11 @@ UDS_GET_TICKET_RESPONSE = {
|
||||
CALLER_HOST = ('host', 12345)
|
||||
REMOTE_HOST = ('127.0.0.1', 54876)
|
||||
|
||||
def uds_response(_, ticket: bytes, msg: str, queryParams: typing.Optional[typing.Mapping[str, str]] = None) -> typing.Dict[str, typing.Any]:
|
||||
if msg == 'stop':
|
||||
return {}
|
||||
|
||||
return UDS_GET_TICKET_RESPONSE
|
||||
|
||||
class TestTunnel(IsolatedAsyncioTestCase):
|
||||
async def test_get_ticket_from_uds(self) -> None:
|
||||
@ -58,7 +64,8 @@ class TestTunnel(IsolatedAsyncioTestCase):
|
||||
'uds_tunnel.tunnel.TunnelProtocol._readFromUDS',
|
||||
new_callable=tools.AsyncMock,
|
||||
) as m:
|
||||
m.return_value = UDS_GET_TICKET_RESPONSE
|
||||
m.side_effect = uds_response
|
||||
#m.return_value = UDS_GET_TICKET_RESPONSE
|
||||
for i in range(0, 100):
|
||||
ticket = ''.join(
|
||||
random.choices(
|
||||
@ -101,7 +108,7 @@ class TestTunnel(IsolatedAsyncioTestCase):
|
||||
'uds_tunnel.tunnel.TunnelProtocol._readFromUDS',
|
||||
new_callable=tools.AsyncMock,
|
||||
) as m:
|
||||
m.return_value = {}
|
||||
m.side_effect = uds_response
|
||||
counter = mock.MagicMock()
|
||||
counter.sent = 123456789
|
||||
counter.recv = 987654321
|
||||
|
Loading…
x
Reference in New Issue
Block a user