mirror of
https://github.com/dkmstr/openuds.git
synced 2025-03-12 04:58:34 +03:00
backported 4.0 version
This commit is contained in:
parent
2189267358
commit
0a15f7bdce
@ -0,0 +1,7 @@
|
||||
"""
|
||||
Copyright (c) 2023 Adolfo Gómez García <dkmaster@dkmon.com>
|
||||
|
||||
This software is released under the MIT License.
|
||||
https://opensource.org/licenses/MIT
|
||||
"""
|
||||
|
@ -52,7 +52,7 @@ INTERVAL = 2 # Interval in seconds between stats update
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class StatsSingleCounter:
|
||||
def __init__(self, parent: 'Stats', for_receiving=True) -> None:
|
||||
def __init__(self, parent: 'StatsManager', for_receiving=True) -> None:
|
||||
if for_receiving:
|
||||
self.adder = parent.add_recv
|
||||
else:
|
||||
@ -63,7 +63,7 @@ class StatsSingleCounter:
|
||||
return self
|
||||
|
||||
|
||||
class Stats:
|
||||
class StatsManager:
|
||||
ns: 'Namespace'
|
||||
last_sent: int
|
||||
sent: int
|
||||
@ -100,9 +100,11 @@ class Stats:
|
||||
self.sent += size
|
||||
self.update()
|
||||
|
||||
@property
|
||||
def as_sent_counter(self) -> 'StatsSingleCounter':
|
||||
return StatsSingleCounter(self, False)
|
||||
|
||||
@property
|
||||
def as_recv_counter(self) -> 'StatsSingleCounter':
|
||||
return StatsSingleCounter(self, True)
|
||||
|
||||
|
@ -35,9 +35,8 @@ import socket
|
||||
|
||||
import aiohttp
|
||||
|
||||
from . import consts
|
||||
from . import config
|
||||
from . import stats
|
||||
from . import consts, config, stats, tunnel_client
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -47,31 +46,36 @@ if typing.TYPE_CHECKING:
|
||||
|
||||
# Protocol
|
||||
class TunnelProtocol(asyncio.Protocol):
|
||||
# Transport and other side of tunnel
|
||||
# owner Proxy class
|
||||
owner: 'proxy.Proxy'
|
||||
|
||||
# Transport and client
|
||||
transport: 'asyncio.transports.Transport'
|
||||
other_side: 'TunnelProtocol'
|
||||
# Current state
|
||||
client: typing.Optional['tunnel_client.TunnelClientProtocol']
|
||||
|
||||
# Current state, could be:
|
||||
# - do_command: Waiting for command
|
||||
# - do_proxy: Proxying data
|
||||
runner: typing.Any # In fact, typing.Callable[[bytes], None], but mypy complains on checking variables that are callables on classes
|
||||
# Command buffer
|
||||
cmd: bytes
|
||||
|
||||
# Ticket
|
||||
notify_ticket: bytes # Only exists on "slave" transport (that is, tunnel from us to remote machine)
|
||||
# owner Proxy class
|
||||
owner: 'proxy.Proxy'
|
||||
notify_ticket: bytes # Only exists when we have created the client connection
|
||||
|
||||
# source of connection
|
||||
source: typing.Tuple[str, int]
|
||||
# and destination
|
||||
destination: typing.Tuple[str, int]
|
||||
|
||||
# Counters & stats related
|
||||
stats_manager: stats.Stats
|
||||
# counter
|
||||
counter: stats.StatsSingleCounter
|
||||
stats_manager: stats.StatsManager
|
||||
|
||||
# If there is a timeout task running
|
||||
timeout_task: typing.Optional[asyncio.Task] = None
|
||||
|
||||
is_server_side: bool
|
||||
|
||||
def __init__(
|
||||
self, owner: 'proxy.Proxy', other_side: typing.Optional['TunnelProtocol'] = None
|
||||
self, owner: 'proxy.Proxy'
|
||||
) -> None:
|
||||
# If no other side is given, we are the server part
|
||||
super().__init__()
|
||||
@ -84,22 +88,13 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
|
||||
# If other_side is given, we are the client part (that is, the tunnel from us to remote machine)
|
||||
# In this case, only do_proxy is used
|
||||
if other_side:
|
||||
self.other_side = other_side
|
||||
self.is_server_side = False
|
||||
self.stats_manager = other_side.stats_manager
|
||||
self.counter = self.stats_manager.as_recv_counter()
|
||||
self.runner = self.do_proxy
|
||||
else: # We are the server part, that is the tunnel from client machine to us
|
||||
self.other_side = self
|
||||
self.is_server_side = True
|
||||
self.stats_manager = stats.Stats(owner.ns)
|
||||
self.counter = self.stats_manager.as_sent_counter()
|
||||
# We start processing command
|
||||
# After command, we can process stats or do_proxy, that is the "normal" operation
|
||||
self.runner = self.do_command
|
||||
# Set starting timeout task, se we dont get hunged on connections without data (or insufficient data)
|
||||
self.set_timeout(self.owner.cfg.command_timeout)
|
||||
self.client = None
|
||||
self.stats_manager = stats.StatsManager(owner.ns)
|
||||
# We start processing command
|
||||
# After command, we can process stats or do_proxy, that is the "normal" operation
|
||||
self.runner = self.do_command
|
||||
# Set starting timeout task, se we dont get hunged on connections without data (or insufficient data)
|
||||
self.set_timeout(self.owner.cfg.command_timeout)
|
||||
|
||||
def process_open(self) -> None:
|
||||
# Open Command has the ticket behind it
|
||||
@ -120,7 +115,7 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
async def open_other_side() -> None:
|
||||
async def open_client() -> None:
|
||||
try:
|
||||
result = await TunnelProtocol.get_ticket_from_uds(
|
||||
self.owner.cfg, ticket, self.source
|
||||
@ -148,13 +143,12 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
or (self.owner.cfg.ipv6 and not '.' in self.destination[0])
|
||||
else socket.AF_INET
|
||||
)
|
||||
(_, protocol) = await loop.create_connection(
|
||||
lambda: TunnelProtocol(self.owner, self),
|
||||
(_, self.client) = await loop.create_connection(
|
||||
lambda: tunnel_client.TunnelClientProtocol(self),
|
||||
self.destination[0],
|
||||
self.destination[1],
|
||||
family=family,
|
||||
)
|
||||
self.other_side = typing.cast('TunnelProtocol', protocol)
|
||||
|
||||
# Resume reading
|
||||
self.transport.resume_reading()
|
||||
@ -165,7 +159,7 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
self.close_connection()
|
||||
|
||||
# add open other side to the loop
|
||||
loop.create_task(open_other_side())
|
||||
loop.create_task(open_client())
|
||||
# From now, proxy connection
|
||||
self.runner = self.do_proxy
|
||||
|
||||
@ -266,22 +260,20 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
# if not enough data to process command, wait for more
|
||||
|
||||
def do_proxy(self, data: bytes) -> None:
|
||||
self.counter.add(len(data))
|
||||
# do_proxy will only be called if other_side is set to the other side of the tunnel
|
||||
self.other_side.transport.write(data)
|
||||
self.stats_manager.as_sent_counter.add(len(data))
|
||||
# do_proxy will only be called if other_side is set to the other side of the tunnel, no None is possible
|
||||
typing.cast('tunnel_client.TunnelClientProtocol', self.client).send(data)
|
||||
|
||||
# inherited from asyncio.Protocol
|
||||
def send(self, data: bytes) -> None:
|
||||
self.stats_manager.as_recv_counter.add(len(data))
|
||||
self.transport.write(data)
|
||||
|
||||
def connection_made(self, transport: 'asyncio.transports.BaseTransport') -> None:
|
||||
logger.debug('Connection made: %s', transport.get_extra_info('peername'))
|
||||
|
||||
# We know for sure that the transport is a Transport.
|
||||
self.transport = typing.cast('asyncio.transports.Transport', transport)
|
||||
self.cmd = b''
|
||||
self.source = self.transport.get_extra_info('peername')
|
||||
|
||||
def data_received(self, data: bytes):
|
||||
self.runner(data) # send data to current runner (command or proxy)
|
||||
def close_connection(self):
|
||||
try:
|
||||
if not self.transport.is_closing():
|
||||
self.transport.close()
|
||||
except Exception:
|
||||
pass # Ignore errors
|
||||
|
||||
def notify_end(self):
|
||||
if self.notify_ticket:
|
||||
@ -300,24 +292,33 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
)
|
||||
)
|
||||
self.notify_ticket = b'' # Clean up so no more notifications
|
||||
|
||||
if self.other_side is self: # no other side, simple connection log
|
||||
else:
|
||||
logger.info('TERMINATED %s', self.pretty_source())
|
||||
|
||||
if self.is_server_side:
|
||||
self.stats_manager.close()
|
||||
self.owner.finished.set()
|
||||
self.stats_manager.close()
|
||||
self.owner.finished.set()
|
||||
|
||||
def connection_made(self, transport: 'asyncio.transports.BaseTransport') -> None:
|
||||
logger.debug('Connection made: %s', transport.get_extra_info('peername'))
|
||||
|
||||
# We know for sure that the transport is a Transport.
|
||||
self.transport = typing.cast('asyncio.transports.Transport', transport)
|
||||
self.cmd = b''
|
||||
self.source = self.transport.get_extra_info('peername')
|
||||
|
||||
def data_received(self, data: bytes):
|
||||
self.runner(data) # send data to current runner (command or proxy)
|
||||
|
||||
def connection_lost(self, exc: typing.Optional[Exception]) -> None:
|
||||
# Ensure close other side if not server_side
|
||||
try:
|
||||
self.other_side.transport.close()
|
||||
except Exception:
|
||||
pass
|
||||
if self.client:
|
||||
self.client.close_connection()
|
||||
|
||||
self.notify_end()
|
||||
|
||||
# helpers
|
||||
# *****************
|
||||
# * Helpers *
|
||||
# *****************
|
||||
@staticmethod
|
||||
def pretty_address(address: typing.Tuple[str, int]) -> str:
|
||||
if ':' in address[0]:
|
||||
@ -331,11 +332,6 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
def pretty_destination(self) -> str:
|
||||
return TunnelProtocol.pretty_address(self.destination)
|
||||
|
||||
def close_connection(self):
|
||||
try:
|
||||
self.transport.close()
|
||||
except Exception:
|
||||
pass # Ignore errors
|
||||
|
||||
@staticmethod
|
||||
async def _read_from_uds(
|
||||
@ -387,11 +383,11 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
|
||||
@staticmethod
|
||||
async def notify_end_to_uds(
|
||||
cfg: config.ConfigurationType, ticket: bytes, counter: stats.Stats
|
||||
cfg: config.ConfigurationType, ticket: bytes, stats_mngr: stats.StatsManager
|
||||
) -> None:
|
||||
await TunnelProtocol._read_from_uds(
|
||||
cfg,
|
||||
ticket,
|
||||
'stop',
|
||||
{'sent': str(counter.sent), 'recv': str(counter.recv)},
|
||||
{'sent': str(stats_mngr.sent), 'recv': str(stats_mngr.recv)},
|
||||
)
|
||||
|
88
tunnel-server/src/uds_tunnel/tunnel_client.py
Normal file
88
tunnel-server/src/uds_tunnel/tunnel_client.py
Normal file
@ -0,0 +1,88 @@
|
||||
"""
|
||||
Copyright (c) 2023 Adolfo Gómez García <dkmaster@dkmon.com>
|
||||
|
||||
This software is released under the MIT License.
|
||||
https://opensource.org/licenses/MIT
|
||||
"""
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright (c) 2022 Virtual Cable S.L.U.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification,
|
||||
# are permitted provided that the following conditions are met:
|
||||
#
|
||||
# * Redistributions of source code must retain the above copyright notice,
|
||||
# this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
# * Neither the name of Virtual Cable S.L. nor the names of its contributors
|
||||
# may be used to endorse or promote products derived from this software
|
||||
# without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
'''
|
||||
Author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
'''
|
||||
import asyncio
|
||||
import typing
|
||||
import logging
|
||||
|
||||
from . import consts, config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from . import tunnel, stats
|
||||
|
||||
|
||||
# Protocol
|
||||
class TunnelClientProtocol(asyncio.Protocol):
|
||||
# Transport and other side of tunnel
|
||||
transport: 'asyncio.transports.Transport'
|
||||
receiver: 'tunnel.TunnelProtocol'
|
||||
destination: typing.Tuple[str, int]
|
||||
|
||||
def __init__(
|
||||
self, receiver: 'tunnel.TunnelProtocol'
|
||||
) -> None:
|
||||
# If no other side is given, we are the server part
|
||||
super().__init__()
|
||||
# transport is undefined until connection_made is called
|
||||
self.receiver = receiver
|
||||
self.notify_ticket = b''
|
||||
self.destination = ('', 0)
|
||||
|
||||
def data_received(self, data: bytes):
|
||||
self.receiver.send(data)
|
||||
|
||||
def connection_made(self, transport: 'asyncio.transports.BaseTransport') -> None:
|
||||
self.transport = typing.cast('asyncio.transports.Transport', transport)
|
||||
|
||||
def connection_lost(self, exc: typing.Optional[Exception]) -> None:
|
||||
# Ensure close other side if not server_side
|
||||
try:
|
||||
self.receiver.close_connection()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def send(self, data: bytes):
|
||||
self.transport.write(data)
|
||||
|
||||
def close_connection(self):
|
||||
try:
|
||||
if not self.transport.is_closing():
|
||||
self.transport.close()
|
||||
except Exception:
|
||||
pass # Ignore errors
|
@ -189,8 +189,10 @@ async def tunnel_proc_async(
|
||||
logger.debug('Out of loop, stopping tasks: %s, running: %s', tasks, do_stop.is_set())
|
||||
|
||||
# If any task is still running, cancel it
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
asyncio.gather(*tasks, return_exceptions=True).cancel()
|
||||
|
||||
# for task in tasks:
|
||||
# task.cancel()
|
||||
|
||||
# Wait for all tasks to finish
|
||||
await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)
|
||||
|
Loading…
x
Reference in New Issue
Block a user