mirror of
https://github.com/dkmstr/openuds.git
synced 2025-03-13 08:58:35 +03:00
Added full tunnel test
This commit is contained in:
parent
0f5f3df3f0
commit
49dddbfce7
@ -51,7 +51,6 @@ class Proxy:
|
||||
def __init__(self, cfg: 'config.ConfigurationType', ns: 'Namespace') -> None:
|
||||
self.cfg = cfg
|
||||
self.ns = ns
|
||||
self.finished = asyncio.Future() # not done yet
|
||||
|
||||
# Method responsible of proxying requests
|
||||
async def __call__(self, source: socket.socket, context: 'ssl.SSLContext') -> None:
|
||||
@ -71,6 +70,7 @@ class Proxy:
|
||||
loop = asyncio.get_running_loop()
|
||||
# Handshake correct in this point, upgrade the connection to TSL and let
|
||||
# the protocol controller do the rest
|
||||
self.finished = loop.create_future()
|
||||
|
||||
# Upgrade connection to SSL, and use asyncio to handle the rest
|
||||
try:
|
||||
|
@ -66,7 +66,7 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
# counter
|
||||
counter: stats.StatsSingleCounter
|
||||
# If there is a timeout task running
|
||||
timeout_task: typing.Optional[asyncio.Task]
|
||||
timeout_task: typing.Optional[asyncio.Task] = None
|
||||
|
||||
def __init__(
|
||||
self, owner: 'proxy.Proxy', other_side: typing.Optional['TunnelProtocol'] = None
|
||||
@ -83,6 +83,8 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
self.stats_manager = stats.Stats(owner.ns)
|
||||
self.counter = self.stats_manager.as_sent_counter()
|
||||
self.runner = self.do_command
|
||||
# Set starting timeout task, se we dont get hunged on connections without data
|
||||
self.set_timeout(consts.TIMEOUT_COMMAND)
|
||||
|
||||
# transport is undefined until connection_made is called
|
||||
self.cmd = b''
|
||||
@ -90,15 +92,14 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
self.owner = owner
|
||||
self.source = ('', 0)
|
||||
self.destination = ('', 0)
|
||||
self.timeout_task = None
|
||||
|
||||
# Set starting timeout task, se we dont get hunged on connections without data
|
||||
self.set_timeout(consts.TIMEOUT_COMMAND)
|
||||
|
||||
def process_open(self) -> None:
|
||||
# Open Command has the ticket behind it
|
||||
|
||||
if len(self.cmd) < consts.TICKET_LENGTH + consts.COMMAND_LENGTH:
|
||||
# Reactivate timeout, will be deactivated on do_command
|
||||
self.set_timeout(consts.TIMEOUT_COMMAND)
|
||||
return # Wait for more data to complete OPEN command
|
||||
|
||||
# Ticket received, now process it with UDS
|
||||
@ -192,6 +193,7 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
self.close_connection()
|
||||
|
||||
async def timeout(self, wait: int) -> None:
|
||||
""" Timeout can only occur while waiting for a command."""
|
||||
try:
|
||||
await asyncio.sleep(wait)
|
||||
logger.error('TIMEOUT FROM %s', self.pretty_source())
|
||||
@ -213,7 +215,8 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
self.timeout_task = asyncio.create_task(self.timeout(wait))
|
||||
|
||||
def clean_timeout(self) -> None:
|
||||
"""Clean the timeout task if any."""
|
||||
"""Clean the timeout task if any.
|
||||
"""
|
||||
if self.timeout_task:
|
||||
self.timeout_task.cancel()
|
||||
self.timeout_task = None
|
||||
|
@ -235,7 +235,7 @@ def tunnel_main(args: 'argparse.Namespace') -> None:
|
||||
|
||||
# If running as root, and requested drop privileges after port bind
|
||||
if os.getuid() == 0 and cfg.user:
|
||||
logger.debug('Changing to user %s', cfg.user)
|
||||
logger.debug('Changing to user %s', cfg.user)
|
||||
pwu = pwd.getpwnam(cfg.user)
|
||||
# os.setgid(pwu.pw_gid)
|
||||
os.setuid(pwu.pw_uid)
|
||||
@ -271,7 +271,7 @@ def tunnel_main(args: 'argparse.Namespace') -> None:
|
||||
|
||||
prcs = processes.Processes(tunnel_proc_async, cfg, stats_collector.ns)
|
||||
|
||||
running.set()
|
||||
running.set() # Signal we are running
|
||||
with ThreadPoolExecutor(max_workers=256) as executor:
|
||||
try:
|
||||
while running.is_set():
|
||||
|
@ -31,12 +31,13 @@ Author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
import typing
|
||||
import random
|
||||
import asyncio
|
||||
import string
|
||||
import logging
|
||||
from unittest import IsolatedAsyncioTestCase, mock
|
||||
|
||||
from uds_tunnel import consts
|
||||
|
||||
from .utils import tuntools
|
||||
from .utils import tuntools, tools
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -49,25 +50,145 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
|
||||
|
||||
async def test_tunnel_fail_cmd(self) -> None:
|
||||
consts.TIMEOUT_COMMAND = 0.1 # type: ignore # timeout is a final variable, but we need to change it for testing speed
|
||||
for i in range(0, 100, 10):
|
||||
# Set timeout to 1 seconds
|
||||
bad_cmd = bytes(random.randint(0, 255) for _ in range(i)) # Some garbage
|
||||
logger.info(f'Testing invalid command with {bad_cmd!r}')
|
||||
for host in ('127.0.0.1', '::1'):
|
||||
# Remote is not really important in this tests, will fail before using it
|
||||
async with tuntools.create_tunnel_proc(
|
||||
host, 7777, '127.0.0.1', 12345, workers=1
|
||||
) as cfg:
|
||||
# Test on ipv4 and ipv6
|
||||
for host in ('127.0.0.1', '::1'):
|
||||
# Remote is not really important in this tests, will fail before using it
|
||||
async with tuntools.create_tunnel_proc(
|
||||
host,
|
||||
7777,
|
||||
'127.0.0.1',
|
||||
12345,
|
||||
) as cfg:
|
||||
for i in range(0, 8192, 128):
|
||||
# Set timeout to 1 seconds
|
||||
bad_cmd = bytes(
|
||||
random.randint(0, 255) for _ in range(i)
|
||||
) # Some garbage
|
||||
logger.info(f'Testing invalid command with {bad_cmd!r}')
|
||||
# On full, we need the handshake to be done, before connecting
|
||||
async with tuntools.open_tunnel_client(cfg, use_tunnel_handshake=True) as (creader, cwriter):
|
||||
# Our "test" server will simple "eat" the handshake, but we need to do it
|
||||
async with tuntools.open_tunnel_client(
|
||||
cfg, use_tunnel_handshake=True
|
||||
) as (creader, cwriter):
|
||||
cwriter.write(bad_cmd)
|
||||
await cwriter.drain()
|
||||
# Read response
|
||||
# Read response
|
||||
data = await creader.read(1024)
|
||||
# if len(bad_cmd) < consts.COMMAND_LENGTH, response will be RESPONSE_ERROR_TIMEOUT
|
||||
if len(bad_cmd) >= consts.COMMAND_LENGTH:
|
||||
self.assertEqual(data, consts.RESPONSE_ERROR_COMMAND)
|
||||
else:
|
||||
self.assertEqual(data, consts.RESPONSE_ERROR_TIMEOUT)
|
||||
|
||||
|
||||
async def test_tunnel_test(self) -> None:
|
||||
for host in ('127.0.0.1', '::1'):
|
||||
# Remote is not really important in this tests, will return ok before using it (this is a TEST command, not OPEN)
|
||||
async with tuntools.create_tunnel_proc(
|
||||
host,
|
||||
7777,
|
||||
'127.0.0.1',
|
||||
12345,
|
||||
) as cfg:
|
||||
for i in range(10): # Several times
|
||||
# On full, we need the handshake to be done, before connecting
|
||||
# Our "test" server will simple "eat" the handshake, but we need to do it
|
||||
async with tuntools.open_tunnel_client(
|
||||
cfg, use_tunnel_handshake=True
|
||||
) as (creader, cwriter):
|
||||
cwriter.write(consts.COMMAND_TEST)
|
||||
await cwriter.drain()
|
||||
# Read response
|
||||
data = await creader.read(1024)
|
||||
self.assertEqual(data, consts.RESPONSE_OK)
|
||||
|
||||
async def test_tunnel_fail_open(self) -> None:
|
||||
consts.TIMEOUT_COMMAND = 0.1 # type: ignore # timeout is a final variable, but we need to change it for testing speed
|
||||
for host in ('127.0.0.1', '::1'):
|
||||
# Remote is NOT important in this tests
|
||||
# create a remote server
|
||||
async with tools.AsyncTCPServer(host=host, port=5445) as server:
|
||||
async with tuntools.create_tunnel_proc(
|
||||
host,
|
||||
7777,
|
||||
server.host,
|
||||
server.port,
|
||||
) as cfg:
|
||||
for i in range(
|
||||
0, consts.TICKET_LENGTH - 1, 4
|
||||
): # All will fail. Any longer will be processed, and mock will return correct don't matter the ticket
|
||||
# Ticket must contain only letters and numbers
|
||||
ticket = ''.join(
|
||||
random.choice(string.ascii_letters + string.digits)
|
||||
for _ in range(i)
|
||||
).encode()
|
||||
# On full, we need the handshake to be done, before connecting
|
||||
# Our "test" server will simple "eat" the handshake, but we need to do it
|
||||
async with tuntools.open_tunnel_client(
|
||||
cfg, use_tunnel_handshake=True
|
||||
) as (creader, cwriter):
|
||||
cwriter.write(consts.COMMAND_OPEN)
|
||||
# fake ticket, consts.TICKET_LENGTH bytes long, letters and numbers. Use a random ticket,
|
||||
cwriter.write(ticket)
|
||||
|
||||
await cwriter.drain()
|
||||
# Read response
|
||||
data = await creader.read(1024)
|
||||
self.assertEqual(data, consts.RESPONSE_ERROR_TIMEOUT)
|
||||
|
||||
async def test_tunnel_open(self) -> None:
|
||||
for host in ('127.0.0.1', '::1'):
|
||||
received: bytes = b''
|
||||
callback_invoked: asyncio.Event = asyncio.Event()
|
||||
|
||||
def callback(data: bytes) -> None:
|
||||
nonlocal received
|
||||
received += data
|
||||
# if data contains EOS marcker ('STREAM_END'), we are done
|
||||
if b'STREAM_END' in data:
|
||||
callback_invoked.set()
|
||||
|
||||
# Remote is important in this tests
|
||||
# create a remote server
|
||||
async with tools.AsyncTCPServer(
|
||||
host=host, port=5445, callback=callback
|
||||
) as server:
|
||||
async with tuntools.create_tunnel_proc(
|
||||
host,
|
||||
7777,
|
||||
server.host,
|
||||
server.port,
|
||||
) as cfg:
|
||||
for i in range(10):
|
||||
# Create a random valid ticket
|
||||
ticket = ''.join(
|
||||
random.choice(string.ascii_letters + string.digits)
|
||||
for _ in range(consts.TICKET_LENGTH)
|
||||
).encode()
|
||||
# On full, we need the handshake to be done, before connecting
|
||||
# Our "test" server will simple "eat" the handshake, but we need to do it
|
||||
async with tuntools.open_tunnel_client(
|
||||
cfg, use_tunnel_handshake=True
|
||||
) as (creader, cwriter):
|
||||
cwriter.write(consts.COMMAND_OPEN)
|
||||
# fake ticket, consts.TICKET_LENGTH bytes long, letters and numbers. Use a random ticket,
|
||||
cwriter.write(ticket)
|
||||
|
||||
await cwriter.drain()
|
||||
# Read response
|
||||
data = await creader.read(1024)
|
||||
self.assertEqual(data, consts.RESPONSE_OK)
|
||||
|
||||
# Data sent will be received by server
|
||||
# One single write will ensure all data is on same packet
|
||||
test_str = b'Some Random Data' + bytes(random.randint(0, 255) for _ in range(512)) + b'STREAM_END'
|
||||
# Clean received data
|
||||
received = b''
|
||||
# And reset event
|
||||
callback_invoked.clear()
|
||||
|
||||
cwriter.write(test_str)
|
||||
await cwriter.drain()
|
||||
|
||||
# Wait for callback to be invoked
|
||||
await callback_invoked.wait()
|
||||
self.assertEqual(received, test_str)
|
||||
|
@ -34,10 +34,9 @@ import io
|
||||
import logging
|
||||
import socket
|
||||
import ssl
|
||||
import tempfile
|
||||
import threading
|
||||
import random
|
||||
import os
|
||||
import typing
|
||||
import json
|
||||
from unittest import mock
|
||||
import multiprocessing
|
||||
|
||||
@ -61,7 +60,7 @@ async def create_tunnel_proc(
|
||||
remote_host: str,
|
||||
remote_port: int,
|
||||
*,
|
||||
workers: int = 1
|
||||
response: typing.Optional[typing.Mapping[str, typing.Any]] = None
|
||||
) -> typing.AsyncGenerator['config.ConfigurationType', None]:
|
||||
# Create the ssl cert
|
||||
cert, key, password = certs.selfSignedCert(listen_host, use_password=False)
|
||||
@ -76,28 +75,35 @@ async def create_tunnel_proc(
|
||||
address=listen_host,
|
||||
port=listen_port,
|
||||
ipv6=':' in listen_host,
|
||||
loglevel='DEBUG',
|
||||
ssl_certificate=cert_file,
|
||||
ssl_certificate_key='',
|
||||
ssl_password=password,
|
||||
ssl_ciphers='',
|
||||
ssl_dhparam='',
|
||||
workers=workers,
|
||||
)
|
||||
args = mock.MagicMock()
|
||||
args.config = io.StringIO(fixtures.TEST_CONFIG.format(**values))
|
||||
args.ipv6 = ':' in listen_host
|
||||
|
||||
return_value: typing.Mapping[str, typing.Any]
|
||||
# Ensure response
|
||||
if response is None:
|
||||
response = conf.UDS_GET_TICKET_RESPONSE(remote_host, remote_port)
|
||||
|
||||
with mock.patch(
|
||||
'uds_tunnel.tunnel.TunnelProtocol._readFromUDS',
|
||||
new_callable=tools.AsyncMock,
|
||||
) as m:
|
||||
m.return_value = conf.UDS_GET_TICKET_RESPONSE(remote_host, remote_port)
|
||||
m.return_value = response
|
||||
|
||||
# Stats collector
|
||||
gs = stats.GlobalStats()
|
||||
# Pipe to send data to tunnel
|
||||
own_end, other_end = multiprocessing.Pipe()
|
||||
|
||||
udstunnel.setup_log(cfg)
|
||||
|
||||
# Set running flag
|
||||
udstunnel.running.set()
|
||||
|
||||
@ -141,6 +147,18 @@ async def create_tunnel_proc(
|
||||
await server.wait_closed()
|
||||
logger.info('Server closed')
|
||||
|
||||
# Ensure log file are removed
|
||||
rootlog = logging.getLogger()
|
||||
for h in rootlog.handlers:
|
||||
if isinstance(h, logging.FileHandler):
|
||||
h.close()
|
||||
# Remove the file if possible, do not fail
|
||||
try:
|
||||
os.unlink(h.baseFilename)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
|
||||
async def create_tunnel_server(
|
||||
cfg: 'config.ConfigurationType', context: 'ssl.SSLContext'
|
||||
|
Loading…
x
Reference in New Issue
Block a user