1
0
mirror of https://github.com/dkmstr/openuds.git synced 2025-03-13 08:58:35 +03:00

Added full tunnel test

This commit is contained in:
Adolfo Gómez García 2022-12-18 22:16:40 +01:00
parent 0f5f3df3f0
commit 49dddbfce7
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
5 changed files with 169 additions and 27 deletions

View File

@ -51,7 +51,6 @@ class Proxy:
def __init__(self, cfg: 'config.ConfigurationType', ns: 'Namespace') -> None:
self.cfg = cfg
self.ns = ns
self.finished = asyncio.Future() # not done yet
# Method responsible of proxying requests
async def __call__(self, source: socket.socket, context: 'ssl.SSLContext') -> None:
@ -71,6 +70,7 @@ class Proxy:
loop = asyncio.get_running_loop()
# Handshake correct in this point, upgrade the connection to TSL and let
# the protocol controller do the rest
self.finished = loop.create_future()
# Upgrade connection to SSL, and use asyncio to handle the rest
try:

View File

@ -66,7 +66,7 @@ class TunnelProtocol(asyncio.Protocol):
# counter
counter: stats.StatsSingleCounter
# If there is a timeout task running
timeout_task: typing.Optional[asyncio.Task]
timeout_task: typing.Optional[asyncio.Task] = None
def __init__(
self, owner: 'proxy.Proxy', other_side: typing.Optional['TunnelProtocol'] = None
@ -83,6 +83,8 @@ class TunnelProtocol(asyncio.Protocol):
self.stats_manager = stats.Stats(owner.ns)
self.counter = self.stats_manager.as_sent_counter()
self.runner = self.do_command
# Set starting timeout task, se we dont get hunged on connections without data
self.set_timeout(consts.TIMEOUT_COMMAND)
# transport is undefined until connection_made is called
self.cmd = b''
@ -90,15 +92,14 @@ class TunnelProtocol(asyncio.Protocol):
self.owner = owner
self.source = ('', 0)
self.destination = ('', 0)
self.timeout_task = None
# Set starting timeout task, se we dont get hunged on connections without data
self.set_timeout(consts.TIMEOUT_COMMAND)
def process_open(self) -> None:
# Open Command has the ticket behind it
if len(self.cmd) < consts.TICKET_LENGTH + consts.COMMAND_LENGTH:
# Reactivate timeout, will be deactivated on do_command
self.set_timeout(consts.TIMEOUT_COMMAND)
return # Wait for more data to complete OPEN command
# Ticket received, now process it with UDS
@ -192,6 +193,7 @@ class TunnelProtocol(asyncio.Protocol):
self.close_connection()
async def timeout(self, wait: int) -> None:
""" Timeout can only occur while waiting for a command."""
try:
await asyncio.sleep(wait)
logger.error('TIMEOUT FROM %s', self.pretty_source())
@ -213,7 +215,8 @@ class TunnelProtocol(asyncio.Protocol):
self.timeout_task = asyncio.create_task(self.timeout(wait))
def clean_timeout(self) -> None:
"""Clean the timeout task if any."""
"""Clean the timeout task if any.
"""
if self.timeout_task:
self.timeout_task.cancel()
self.timeout_task = None

View File

@ -271,7 +271,7 @@ def tunnel_main(args: 'argparse.Namespace') -> None:
prcs = processes.Processes(tunnel_proc_async, cfg, stats_collector.ns)
running.set()
running.set() # Signal we are running
with ThreadPoolExecutor(max_workers=256) as executor:
try:
while running.is_set():

View File

@ -31,12 +31,13 @@ Author: Adolfo Gómez, dkmaster at dkmon dot com
import typing
import random
import asyncio
import string
import logging
from unittest import IsolatedAsyncioTestCase, mock
from uds_tunnel import consts
from .utils import tuntools
from .utils import tuntools, tools
logger = logging.getLogger(__name__)
@ -49,17 +50,26 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
async def test_tunnel_fail_cmd(self) -> None:
consts.TIMEOUT_COMMAND = 0.1 # type: ignore # timeout is a final variable, but we need to change it for testing speed
for i in range(0, 100, 10):
# Set timeout to 1 seconds
bad_cmd = bytes(random.randint(0, 255) for _ in range(i)) # Some garbage
logger.info(f'Testing invalid command with {bad_cmd!r}')
# Test on ipv4 and ipv6
for host in ('127.0.0.1', '::1'):
# Remote is not really important in this tests, will fail before using it
async with tuntools.create_tunnel_proc(
host, 7777, '127.0.0.1', 12345, workers=1
host,
7777,
'127.0.0.1',
12345,
) as cfg:
for i in range(0, 8192, 128):
# Set timeout to 1 seconds
bad_cmd = bytes(
random.randint(0, 255) for _ in range(i)
) # Some garbage
logger.info(f'Testing invalid command with {bad_cmd!r}')
# On full, we need the handshake to be done, before connecting
async with tuntools.open_tunnel_client(cfg, use_tunnel_handshake=True) as (creader, cwriter):
# Our "test" server will simple "eat" the handshake, but we need to do it
async with tuntools.open_tunnel_client(
cfg, use_tunnel_handshake=True
) as (creader, cwriter):
cwriter.write(bad_cmd)
await cwriter.drain()
# Read response
@ -70,4 +80,115 @@ class TestUDSTunnelApp(IsolatedAsyncioTestCase):
else:
self.assertEqual(data, consts.RESPONSE_ERROR_TIMEOUT)
async def test_tunnel_test(self) -> None:
for host in ('127.0.0.1', '::1'):
# Remote is not really important in this tests, will return ok before using it (this is a TEST command, not OPEN)
async with tuntools.create_tunnel_proc(
host,
7777,
'127.0.0.1',
12345,
) as cfg:
for i in range(10): # Several times
# On full, we need the handshake to be done, before connecting
# Our "test" server will simple "eat" the handshake, but we need to do it
async with tuntools.open_tunnel_client(
cfg, use_tunnel_handshake=True
) as (creader, cwriter):
cwriter.write(consts.COMMAND_TEST)
await cwriter.drain()
# Read response
data = await creader.read(1024)
self.assertEqual(data, consts.RESPONSE_OK)
async def test_tunnel_fail_open(self) -> None:
consts.TIMEOUT_COMMAND = 0.1 # type: ignore # timeout is a final variable, but we need to change it for testing speed
for host in ('127.0.0.1', '::1'):
# Remote is NOT important in this tests
# create a remote server
async with tools.AsyncTCPServer(host=host, port=5445) as server:
async with tuntools.create_tunnel_proc(
host,
7777,
server.host,
server.port,
) as cfg:
for i in range(
0, consts.TICKET_LENGTH - 1, 4
): # All will fail. Any longer will be processed, and mock will return correct don't matter the ticket
# Ticket must contain only letters and numbers
ticket = ''.join(
random.choice(string.ascii_letters + string.digits)
for _ in range(i)
).encode()
# On full, we need the handshake to be done, before connecting
# Our "test" server will simple "eat" the handshake, but we need to do it
async with tuntools.open_tunnel_client(
cfg, use_tunnel_handshake=True
) as (creader, cwriter):
cwriter.write(consts.COMMAND_OPEN)
# fake ticket, consts.TICKET_LENGTH bytes long, letters and numbers. Use a random ticket,
cwriter.write(ticket)
await cwriter.drain()
# Read response
data = await creader.read(1024)
self.assertEqual(data, consts.RESPONSE_ERROR_TIMEOUT)
async def test_tunnel_open(self) -> None:
for host in ('127.0.0.1', '::1'):
received: bytes = b''
callback_invoked: asyncio.Event = asyncio.Event()
def callback(data: bytes) -> None:
nonlocal received
received += data
# if data contains EOS marcker ('STREAM_END'), we are done
if b'STREAM_END' in data:
callback_invoked.set()
# Remote is important in this tests
# create a remote server
async with tools.AsyncTCPServer(
host=host, port=5445, callback=callback
) as server:
async with tuntools.create_tunnel_proc(
host,
7777,
server.host,
server.port,
) as cfg:
for i in range(10):
# Create a random valid ticket
ticket = ''.join(
random.choice(string.ascii_letters + string.digits)
for _ in range(consts.TICKET_LENGTH)
).encode()
# On full, we need the handshake to be done, before connecting
# Our "test" server will simple "eat" the handshake, but we need to do it
async with tuntools.open_tunnel_client(
cfg, use_tunnel_handshake=True
) as (creader, cwriter):
cwriter.write(consts.COMMAND_OPEN)
# fake ticket, consts.TICKET_LENGTH bytes long, letters and numbers. Use a random ticket,
cwriter.write(ticket)
await cwriter.drain()
# Read response
data = await creader.read(1024)
self.assertEqual(data, consts.RESPONSE_OK)
# Data sent will be received by server
# One single write will ensure all data is on same packet
test_str = b'Some Random Data' + bytes(random.randint(0, 255) for _ in range(512)) + b'STREAM_END'
# Clean received data
received = b''
# And reset event
callback_invoked.clear()
cwriter.write(test_str)
await cwriter.drain()
# Wait for callback to be invoked
await callback_invoked.wait()
self.assertEqual(received, test_str)

View File

@ -34,10 +34,9 @@ import io
import logging
import socket
import ssl
import tempfile
import threading
import random
import os
import typing
import json
from unittest import mock
import multiprocessing
@ -61,7 +60,7 @@ async def create_tunnel_proc(
remote_host: str,
remote_port: int,
*,
workers: int = 1
response: typing.Optional[typing.Mapping[str, typing.Any]] = None
) -> typing.AsyncGenerator['config.ConfigurationType', None]:
# Create the ssl cert
cert, key, password = certs.selfSignedCert(listen_host, use_password=False)
@ -76,28 +75,35 @@ async def create_tunnel_proc(
address=listen_host,
port=listen_port,
ipv6=':' in listen_host,
loglevel='DEBUG',
ssl_certificate=cert_file,
ssl_certificate_key='',
ssl_password=password,
ssl_ciphers='',
ssl_dhparam='',
workers=workers,
)
args = mock.MagicMock()
args.config = io.StringIO(fixtures.TEST_CONFIG.format(**values))
args.ipv6 = ':' in listen_host
return_value: typing.Mapping[str, typing.Any]
# Ensure response
if response is None:
response = conf.UDS_GET_TICKET_RESPONSE(remote_host, remote_port)
with mock.patch(
'uds_tunnel.tunnel.TunnelProtocol._readFromUDS',
new_callable=tools.AsyncMock,
) as m:
m.return_value = conf.UDS_GET_TICKET_RESPONSE(remote_host, remote_port)
m.return_value = response
# Stats collector
gs = stats.GlobalStats()
# Pipe to send data to tunnel
own_end, other_end = multiprocessing.Pipe()
udstunnel.setup_log(cfg)
# Set running flag
udstunnel.running.set()
@ -141,6 +147,18 @@ async def create_tunnel_proc(
await server.wait_closed()
logger.info('Server closed')
# Ensure log file are removed
rootlog = logging.getLogger()
for h in rootlog.handlers:
if isinstance(h, logging.FileHandler):
h.close()
# Remove the file if possible, do not fail
try:
os.unlink(h.baseFilename)
except Exception:
pass
async def create_tunnel_server(
cfg: 'config.ConfigurationType', context: 'ssl.SSLContext'