1
0
mirror of https://github.com/dkmstr/openuds.git synced 2025-03-12 04:58:34 +03:00

backported 4.0 version

This commit is contained in:
Adolfo Gómez García 2023-01-05 23:48:36 +01:00
parent 2189267358
commit 0a15f7bdce
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
5 changed files with 166 additions and 71 deletions

View File

@ -0,0 +1,7 @@
"""
Copyright (c) 2023 Adolfo Gómez García <dkmaster@dkmon.com>
This software is released under the MIT License.
https://opensource.org/licenses/MIT
"""

View File

@ -52,7 +52,7 @@ INTERVAL = 2 # Interval in seconds between stats update
logger = logging.getLogger(__name__)
class StatsSingleCounter:
def __init__(self, parent: 'Stats', for_receiving=True) -> None:
def __init__(self, parent: 'StatsManager', for_receiving=True) -> None:
if for_receiving:
self.adder = parent.add_recv
else:
@ -63,7 +63,7 @@ class StatsSingleCounter:
return self
class Stats:
class StatsManager:
ns: 'Namespace'
last_sent: int
sent: int
@ -100,9 +100,11 @@ class Stats:
self.sent += size
self.update()
@property
def as_sent_counter(self) -> 'StatsSingleCounter':
return StatsSingleCounter(self, False)
@property
def as_recv_counter(self) -> 'StatsSingleCounter':
return StatsSingleCounter(self, True)

View File

@ -35,9 +35,8 @@ import socket
import aiohttp
from . import consts
from . import config
from . import stats
from . import consts, config, stats, tunnel_client
logger = logging.getLogger(__name__)
@ -47,31 +46,36 @@ if typing.TYPE_CHECKING:
# Protocol
class TunnelProtocol(asyncio.Protocol):
# Transport and other side of tunnel
# owner Proxy class
owner: 'proxy.Proxy'
# Transport and client
transport: 'asyncio.transports.Transport'
other_side: 'TunnelProtocol'
# Current state
client: typing.Optional['tunnel_client.TunnelClientProtocol']
# Current state, could be:
# - do_command: Waiting for command
# - do_proxy: Proxying data
runner: typing.Any # In fact, typing.Callable[[bytes], None], but mypy complains on checking variables that are callables on classes
# Command buffer
cmd: bytes
# Ticket
notify_ticket: bytes # Only exists on "slave" transport (that is, tunnel from us to remote machine)
# owner Proxy class
owner: 'proxy.Proxy'
notify_ticket: bytes # Only exists when we have created the client connection
# source of connection
source: typing.Tuple[str, int]
# and destination
destination: typing.Tuple[str, int]
# Counters & stats related
stats_manager: stats.Stats
# counter
counter: stats.StatsSingleCounter
stats_manager: stats.StatsManager
# If there is a timeout task running
timeout_task: typing.Optional[asyncio.Task] = None
is_server_side: bool
def __init__(
self, owner: 'proxy.Proxy', other_side: typing.Optional['TunnelProtocol'] = None
self, owner: 'proxy.Proxy'
) -> None:
# If no other side is given, we are the server part
super().__init__()
@ -84,22 +88,13 @@ class TunnelProtocol(asyncio.Protocol):
# If other_side is given, we are the client part (that is, the tunnel from us to remote machine)
# In this case, only do_proxy is used
if other_side:
self.other_side = other_side
self.is_server_side = False
self.stats_manager = other_side.stats_manager
self.counter = self.stats_manager.as_recv_counter()
self.runner = self.do_proxy
else: # We are the server part, that is the tunnel from client machine to us
self.other_side = self
self.is_server_side = True
self.stats_manager = stats.Stats(owner.ns)
self.counter = self.stats_manager.as_sent_counter()
# We start processing command
# After command, we can process stats or do_proxy, that is the "normal" operation
self.runner = self.do_command
# Set starting timeout task, se we dont get hunged on connections without data (or insufficient data)
self.set_timeout(self.owner.cfg.command_timeout)
self.client = None
self.stats_manager = stats.StatsManager(owner.ns)
# We start processing command
# After command, we can process stats or do_proxy, that is the "normal" operation
self.runner = self.do_command
# Set starting timeout task, se we dont get hunged on connections without data (or insufficient data)
self.set_timeout(self.owner.cfg.command_timeout)
def process_open(self) -> None:
# Open Command has the ticket behind it
@ -120,7 +115,7 @@ class TunnelProtocol(asyncio.Protocol):
loop = asyncio.get_running_loop()
async def open_other_side() -> None:
async def open_client() -> None:
try:
result = await TunnelProtocol.get_ticket_from_uds(
self.owner.cfg, ticket, self.source
@ -148,13 +143,12 @@ class TunnelProtocol(asyncio.Protocol):
or (self.owner.cfg.ipv6 and not '.' in self.destination[0])
else socket.AF_INET
)
(_, protocol) = await loop.create_connection(
lambda: TunnelProtocol(self.owner, self),
(_, self.client) = await loop.create_connection(
lambda: tunnel_client.TunnelClientProtocol(self),
self.destination[0],
self.destination[1],
family=family,
)
self.other_side = typing.cast('TunnelProtocol', protocol)
# Resume reading
self.transport.resume_reading()
@ -165,7 +159,7 @@ class TunnelProtocol(asyncio.Protocol):
self.close_connection()
# add open other side to the loop
loop.create_task(open_other_side())
loop.create_task(open_client())
# From now, proxy connection
self.runner = self.do_proxy
@ -266,22 +260,20 @@ class TunnelProtocol(asyncio.Protocol):
# if not enough data to process command, wait for more
def do_proxy(self, data: bytes) -> None:
self.counter.add(len(data))
# do_proxy will only be called if other_side is set to the other side of the tunnel
self.other_side.transport.write(data)
self.stats_manager.as_sent_counter.add(len(data))
# do_proxy will only be called if other_side is set to the other side of the tunnel, no None is possible
typing.cast('tunnel_client.TunnelClientProtocol', self.client).send(data)
# inherited from asyncio.Protocol
def send(self, data: bytes) -> None:
self.stats_manager.as_recv_counter.add(len(data))
self.transport.write(data)
def connection_made(self, transport: 'asyncio.transports.BaseTransport') -> None:
logger.debug('Connection made: %s', transport.get_extra_info('peername'))
# We know for sure that the transport is a Transport.
self.transport = typing.cast('asyncio.transports.Transport', transport)
self.cmd = b''
self.source = self.transport.get_extra_info('peername')
def data_received(self, data: bytes):
self.runner(data) # send data to current runner (command or proxy)
def close_connection(self):
try:
if not self.transport.is_closing():
self.transport.close()
except Exception:
pass # Ignore errors
def notify_end(self):
if self.notify_ticket:
@ -300,24 +292,33 @@ class TunnelProtocol(asyncio.Protocol):
)
)
self.notify_ticket = b'' # Clean up so no more notifications
if self.other_side is self: # no other side, simple connection log
else:
logger.info('TERMINATED %s', self.pretty_source())
if self.is_server_side:
self.stats_manager.close()
self.owner.finished.set()
self.stats_manager.close()
self.owner.finished.set()
def connection_made(self, transport: 'asyncio.transports.BaseTransport') -> None:
logger.debug('Connection made: %s', transport.get_extra_info('peername'))
# We know for sure that the transport is a Transport.
self.transport = typing.cast('asyncio.transports.Transport', transport)
self.cmd = b''
self.source = self.transport.get_extra_info('peername')
def data_received(self, data: bytes):
self.runner(data) # send data to current runner (command or proxy)
def connection_lost(self, exc: typing.Optional[Exception]) -> None:
# Ensure close other side if not server_side
try:
self.other_side.transport.close()
except Exception:
pass
if self.client:
self.client.close_connection()
self.notify_end()
# helpers
# *****************
# * Helpers *
# *****************
@staticmethod
def pretty_address(address: typing.Tuple[str, int]) -> str:
if ':' in address[0]:
@ -331,11 +332,6 @@ class TunnelProtocol(asyncio.Protocol):
def pretty_destination(self) -> str:
return TunnelProtocol.pretty_address(self.destination)
def close_connection(self):
try:
self.transport.close()
except Exception:
pass # Ignore errors
@staticmethod
async def _read_from_uds(
@ -387,11 +383,11 @@ class TunnelProtocol(asyncio.Protocol):
@staticmethod
async def notify_end_to_uds(
cfg: config.ConfigurationType, ticket: bytes, counter: stats.Stats
cfg: config.ConfigurationType, ticket: bytes, stats_mngr: stats.StatsManager
) -> None:
await TunnelProtocol._read_from_uds(
cfg,
ticket,
'stop',
{'sent': str(counter.sent), 'recv': str(counter.recv)},
{'sent': str(stats_mngr.sent), 'recv': str(stats_mngr.recv)},
)

View File

@ -0,0 +1,88 @@
"""
Copyright (c) 2023 Adolfo Gómez García <dkmaster@dkmon.com>
This software is released under the MIT License.
https://opensource.org/licenses/MIT
"""
# -*- coding: utf-8 -*-
#
# Copyright (c) 2022 Virtual Cable S.L.U.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of Virtual Cable S.L. nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
'''
Author: Adolfo Gómez, dkmaster at dkmon dot com
'''
import asyncio
import typing
import logging
from . import consts, config
logger = logging.getLogger(__name__)
if typing.TYPE_CHECKING:
from . import tunnel, stats
# Protocol
class TunnelClientProtocol(asyncio.Protocol):
# Transport and other side of tunnel
transport: 'asyncio.transports.Transport'
receiver: 'tunnel.TunnelProtocol'
destination: typing.Tuple[str, int]
def __init__(
self, receiver: 'tunnel.TunnelProtocol'
) -> None:
# If no other side is given, we are the server part
super().__init__()
# transport is undefined until connection_made is called
self.receiver = receiver
self.notify_ticket = b''
self.destination = ('', 0)
def data_received(self, data: bytes):
self.receiver.send(data)
def connection_made(self, transport: 'asyncio.transports.BaseTransport') -> None:
self.transport = typing.cast('asyncio.transports.Transport', transport)
def connection_lost(self, exc: typing.Optional[Exception]) -> None:
# Ensure close other side if not server_side
try:
self.receiver.close_connection()
except Exception:
pass
def send(self, data: bytes):
self.transport.write(data)
def close_connection(self):
try:
if not self.transport.is_closing():
self.transport.close()
except Exception:
pass # Ignore errors

View File

@ -189,8 +189,10 @@ async def tunnel_proc_async(
logger.debug('Out of loop, stopping tasks: %s, running: %s', tasks, do_stop.is_set())
# If any task is still running, cancel it
for task in tasks:
task.cancel()
asyncio.gather(*tasks, return_exceptions=True).cancel()
# for task in tasks:
# task.cancel()
# Wait for all tasks to finish
await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)