mirror of
https://github.com/dkmstr/openuds.git
synced 2025-03-12 04:58:34 +03:00
backported 4.0 version
This commit is contained in:
parent
2189267358
commit
0a15f7bdce
@ -0,0 +1,7 @@
|
|||||||
|
"""
|
||||||
|
Copyright (c) 2023 Adolfo Gómez García <dkmaster@dkmon.com>
|
||||||
|
|
||||||
|
This software is released under the MIT License.
|
||||||
|
https://opensource.org/licenses/MIT
|
||||||
|
"""
|
||||||
|
|
@ -52,7 +52,7 @@ INTERVAL = 2 # Interval in seconds between stats update
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class StatsSingleCounter:
|
class StatsSingleCounter:
|
||||||
def __init__(self, parent: 'Stats', for_receiving=True) -> None:
|
def __init__(self, parent: 'StatsManager', for_receiving=True) -> None:
|
||||||
if for_receiving:
|
if for_receiving:
|
||||||
self.adder = parent.add_recv
|
self.adder = parent.add_recv
|
||||||
else:
|
else:
|
||||||
@ -63,7 +63,7 @@ class StatsSingleCounter:
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
class Stats:
|
class StatsManager:
|
||||||
ns: 'Namespace'
|
ns: 'Namespace'
|
||||||
last_sent: int
|
last_sent: int
|
||||||
sent: int
|
sent: int
|
||||||
@ -100,9 +100,11 @@ class Stats:
|
|||||||
self.sent += size
|
self.sent += size
|
||||||
self.update()
|
self.update()
|
||||||
|
|
||||||
|
@property
|
||||||
def as_sent_counter(self) -> 'StatsSingleCounter':
|
def as_sent_counter(self) -> 'StatsSingleCounter':
|
||||||
return StatsSingleCounter(self, False)
|
return StatsSingleCounter(self, False)
|
||||||
|
|
||||||
|
@property
|
||||||
def as_recv_counter(self) -> 'StatsSingleCounter':
|
def as_recv_counter(self) -> 'StatsSingleCounter':
|
||||||
return StatsSingleCounter(self, True)
|
return StatsSingleCounter(self, True)
|
||||||
|
|
||||||
|
@ -35,9 +35,8 @@ import socket
|
|||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
from . import consts
|
from . import consts, config, stats, tunnel_client
|
||||||
from . import config
|
|
||||||
from . import stats
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -47,31 +46,36 @@ if typing.TYPE_CHECKING:
|
|||||||
|
|
||||||
# Protocol
|
# Protocol
|
||||||
class TunnelProtocol(asyncio.Protocol):
|
class TunnelProtocol(asyncio.Protocol):
|
||||||
# Transport and other side of tunnel
|
# owner Proxy class
|
||||||
|
owner: 'proxy.Proxy'
|
||||||
|
|
||||||
|
# Transport and client
|
||||||
transport: 'asyncio.transports.Transport'
|
transport: 'asyncio.transports.Transport'
|
||||||
other_side: 'TunnelProtocol'
|
client: typing.Optional['tunnel_client.TunnelClientProtocol']
|
||||||
# Current state
|
|
||||||
|
# Current state, could be:
|
||||||
|
# - do_command: Waiting for command
|
||||||
|
# - do_proxy: Proxying data
|
||||||
runner: typing.Any # In fact, typing.Callable[[bytes], None], but mypy complains on checking variables that are callables on classes
|
runner: typing.Any # In fact, typing.Callable[[bytes], None], but mypy complains on checking variables that are callables on classes
|
||||||
# Command buffer
|
# Command buffer
|
||||||
cmd: bytes
|
cmd: bytes
|
||||||
|
|
||||||
# Ticket
|
# Ticket
|
||||||
notify_ticket: bytes # Only exists on "slave" transport (that is, tunnel from us to remote machine)
|
notify_ticket: bytes # Only exists when we have created the client connection
|
||||||
# owner Proxy class
|
|
||||||
owner: 'proxy.Proxy'
|
|
||||||
# source of connection
|
# source of connection
|
||||||
source: typing.Tuple[str, int]
|
source: typing.Tuple[str, int]
|
||||||
|
# and destination
|
||||||
destination: typing.Tuple[str, int]
|
destination: typing.Tuple[str, int]
|
||||||
|
|
||||||
# Counters & stats related
|
# Counters & stats related
|
||||||
stats_manager: stats.Stats
|
stats_manager: stats.StatsManager
|
||||||
# counter
|
|
||||||
counter: stats.StatsSingleCounter
|
|
||||||
# If there is a timeout task running
|
# If there is a timeout task running
|
||||||
timeout_task: typing.Optional[asyncio.Task] = None
|
timeout_task: typing.Optional[asyncio.Task] = None
|
||||||
|
|
||||||
is_server_side: bool
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, owner: 'proxy.Proxy', other_side: typing.Optional['TunnelProtocol'] = None
|
self, owner: 'proxy.Proxy'
|
||||||
) -> None:
|
) -> None:
|
||||||
# If no other side is given, we are the server part
|
# If no other side is given, we are the server part
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -84,17 +88,8 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
|
|
||||||
# If other_side is given, we are the client part (that is, the tunnel from us to remote machine)
|
# If other_side is given, we are the client part (that is, the tunnel from us to remote machine)
|
||||||
# In this case, only do_proxy is used
|
# In this case, only do_proxy is used
|
||||||
if other_side:
|
self.client = None
|
||||||
self.other_side = other_side
|
self.stats_manager = stats.StatsManager(owner.ns)
|
||||||
self.is_server_side = False
|
|
||||||
self.stats_manager = other_side.stats_manager
|
|
||||||
self.counter = self.stats_manager.as_recv_counter()
|
|
||||||
self.runner = self.do_proxy
|
|
||||||
else: # We are the server part, that is the tunnel from client machine to us
|
|
||||||
self.other_side = self
|
|
||||||
self.is_server_side = True
|
|
||||||
self.stats_manager = stats.Stats(owner.ns)
|
|
||||||
self.counter = self.stats_manager.as_sent_counter()
|
|
||||||
# We start processing command
|
# We start processing command
|
||||||
# After command, we can process stats or do_proxy, that is the "normal" operation
|
# After command, we can process stats or do_proxy, that is the "normal" operation
|
||||||
self.runner = self.do_command
|
self.runner = self.do_command
|
||||||
@ -120,7 +115,7 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
|
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
async def open_other_side() -> None:
|
async def open_client() -> None:
|
||||||
try:
|
try:
|
||||||
result = await TunnelProtocol.get_ticket_from_uds(
|
result = await TunnelProtocol.get_ticket_from_uds(
|
||||||
self.owner.cfg, ticket, self.source
|
self.owner.cfg, ticket, self.source
|
||||||
@ -148,13 +143,12 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
or (self.owner.cfg.ipv6 and not '.' in self.destination[0])
|
or (self.owner.cfg.ipv6 and not '.' in self.destination[0])
|
||||||
else socket.AF_INET
|
else socket.AF_INET
|
||||||
)
|
)
|
||||||
(_, protocol) = await loop.create_connection(
|
(_, self.client) = await loop.create_connection(
|
||||||
lambda: TunnelProtocol(self.owner, self),
|
lambda: tunnel_client.TunnelClientProtocol(self),
|
||||||
self.destination[0],
|
self.destination[0],
|
||||||
self.destination[1],
|
self.destination[1],
|
||||||
family=family,
|
family=family,
|
||||||
)
|
)
|
||||||
self.other_side = typing.cast('TunnelProtocol', protocol)
|
|
||||||
|
|
||||||
# Resume reading
|
# Resume reading
|
||||||
self.transport.resume_reading()
|
self.transport.resume_reading()
|
||||||
@ -165,7 +159,7 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
self.close_connection()
|
self.close_connection()
|
||||||
|
|
||||||
# add open other side to the loop
|
# add open other side to the loop
|
||||||
loop.create_task(open_other_side())
|
loop.create_task(open_client())
|
||||||
# From now, proxy connection
|
# From now, proxy connection
|
||||||
self.runner = self.do_proxy
|
self.runner = self.do_proxy
|
||||||
|
|
||||||
@ -266,22 +260,20 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
# if not enough data to process command, wait for more
|
# if not enough data to process command, wait for more
|
||||||
|
|
||||||
def do_proxy(self, data: bytes) -> None:
|
def do_proxy(self, data: bytes) -> None:
|
||||||
self.counter.add(len(data))
|
self.stats_manager.as_sent_counter.add(len(data))
|
||||||
# do_proxy will only be called if other_side is set to the other side of the tunnel
|
# do_proxy will only be called if other_side is set to the other side of the tunnel, no None is possible
|
||||||
self.other_side.transport.write(data)
|
typing.cast('tunnel_client.TunnelClientProtocol', self.client).send(data)
|
||||||
|
|
||||||
# inherited from asyncio.Protocol
|
def send(self, data: bytes) -> None:
|
||||||
|
self.stats_manager.as_recv_counter.add(len(data))
|
||||||
|
self.transport.write(data)
|
||||||
|
|
||||||
def connection_made(self, transport: 'asyncio.transports.BaseTransport') -> None:
|
def close_connection(self):
|
||||||
logger.debug('Connection made: %s', transport.get_extra_info('peername'))
|
try:
|
||||||
|
if not self.transport.is_closing():
|
||||||
# We know for sure that the transport is a Transport.
|
self.transport.close()
|
||||||
self.transport = typing.cast('asyncio.transports.Transport', transport)
|
except Exception:
|
||||||
self.cmd = b''
|
pass # Ignore errors
|
||||||
self.source = self.transport.get_extra_info('peername')
|
|
||||||
|
|
||||||
def data_received(self, data: bytes):
|
|
||||||
self.runner(data) # send data to current runner (command or proxy)
|
|
||||||
|
|
||||||
def notify_end(self):
|
def notify_end(self):
|
||||||
if self.notify_ticket:
|
if self.notify_ticket:
|
||||||
@ -300,24 +292,33 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.notify_ticket = b'' # Clean up so no more notifications
|
self.notify_ticket = b'' # Clean up so no more notifications
|
||||||
|
else:
|
||||||
if self.other_side is self: # no other side, simple connection log
|
|
||||||
logger.info('TERMINATED %s', self.pretty_source())
|
logger.info('TERMINATED %s', self.pretty_source())
|
||||||
|
|
||||||
if self.is_server_side:
|
|
||||||
self.stats_manager.close()
|
self.stats_manager.close()
|
||||||
self.owner.finished.set()
|
self.owner.finished.set()
|
||||||
|
|
||||||
|
def connection_made(self, transport: 'asyncio.transports.BaseTransport') -> None:
|
||||||
|
logger.debug('Connection made: %s', transport.get_extra_info('peername'))
|
||||||
|
|
||||||
|
# We know for sure that the transport is a Transport.
|
||||||
|
self.transport = typing.cast('asyncio.transports.Transport', transport)
|
||||||
|
self.cmd = b''
|
||||||
|
self.source = self.transport.get_extra_info('peername')
|
||||||
|
|
||||||
|
def data_received(self, data: bytes):
|
||||||
|
self.runner(data) # send data to current runner (command or proxy)
|
||||||
|
|
||||||
def connection_lost(self, exc: typing.Optional[Exception]) -> None:
|
def connection_lost(self, exc: typing.Optional[Exception]) -> None:
|
||||||
# Ensure close other side if not server_side
|
# Ensure close other side if not server_side
|
||||||
try:
|
if self.client:
|
||||||
self.other_side.transport.close()
|
self.client.close_connection()
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
self.notify_end()
|
self.notify_end()
|
||||||
|
|
||||||
# helpers
|
# *****************
|
||||||
|
# * Helpers *
|
||||||
|
# *****************
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def pretty_address(address: typing.Tuple[str, int]) -> str:
|
def pretty_address(address: typing.Tuple[str, int]) -> str:
|
||||||
if ':' in address[0]:
|
if ':' in address[0]:
|
||||||
@ -331,11 +332,6 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
def pretty_destination(self) -> str:
|
def pretty_destination(self) -> str:
|
||||||
return TunnelProtocol.pretty_address(self.destination)
|
return TunnelProtocol.pretty_address(self.destination)
|
||||||
|
|
||||||
def close_connection(self):
|
|
||||||
try:
|
|
||||||
self.transport.close()
|
|
||||||
except Exception:
|
|
||||||
pass # Ignore errors
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _read_from_uds(
|
async def _read_from_uds(
|
||||||
@ -387,11 +383,11 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def notify_end_to_uds(
|
async def notify_end_to_uds(
|
||||||
cfg: config.ConfigurationType, ticket: bytes, counter: stats.Stats
|
cfg: config.ConfigurationType, ticket: bytes, stats_mngr: stats.StatsManager
|
||||||
) -> None:
|
) -> None:
|
||||||
await TunnelProtocol._read_from_uds(
|
await TunnelProtocol._read_from_uds(
|
||||||
cfg,
|
cfg,
|
||||||
ticket,
|
ticket,
|
||||||
'stop',
|
'stop',
|
||||||
{'sent': str(counter.sent), 'recv': str(counter.recv)},
|
{'sent': str(stats_mngr.sent), 'recv': str(stats_mngr.recv)},
|
||||||
)
|
)
|
||||||
|
88
tunnel-server/src/uds_tunnel/tunnel_client.py
Normal file
88
tunnel-server/src/uds_tunnel/tunnel_client.py
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
"""
|
||||||
|
Copyright (c) 2023 Adolfo Gómez García <dkmaster@dkmon.com>
|
||||||
|
|
||||||
|
This software is released under the MIT License.
|
||||||
|
https://opensource.org/licenses/MIT
|
||||||
|
"""
|
||||||
|
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
#
|
||||||
|
# Copyright (c) 2022 Virtual Cable S.L.U.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# Redistribution and use in source and binary forms, with or without modification,
|
||||||
|
# are permitted provided that the following conditions are met:
|
||||||
|
#
|
||||||
|
# * Redistributions of source code must retain the above copyright notice,
|
||||||
|
# this list of conditions and the following disclaimer.
|
||||||
|
# * Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
# this list of conditions and the following disclaimer in the documentation
|
||||||
|
# and/or other materials provided with the distribution.
|
||||||
|
# * Neither the name of Virtual Cable S.L. nor the names of its contributors
|
||||||
|
# may be used to endorse or promote products derived from this software
|
||||||
|
# without specific prior written permission.
|
||||||
|
#
|
||||||
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
'''
|
||||||
|
Author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||||
|
'''
|
||||||
|
import asyncio
|
||||||
|
import typing
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from . import consts, config
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
if typing.TYPE_CHECKING:
|
||||||
|
from . import tunnel, stats
|
||||||
|
|
||||||
|
|
||||||
|
# Protocol
|
||||||
|
class TunnelClientProtocol(asyncio.Protocol):
|
||||||
|
# Transport and other side of tunnel
|
||||||
|
transport: 'asyncio.transports.Transport'
|
||||||
|
receiver: 'tunnel.TunnelProtocol'
|
||||||
|
destination: typing.Tuple[str, int]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, receiver: 'tunnel.TunnelProtocol'
|
||||||
|
) -> None:
|
||||||
|
# If no other side is given, we are the server part
|
||||||
|
super().__init__()
|
||||||
|
# transport is undefined until connection_made is called
|
||||||
|
self.receiver = receiver
|
||||||
|
self.notify_ticket = b''
|
||||||
|
self.destination = ('', 0)
|
||||||
|
|
||||||
|
def data_received(self, data: bytes):
|
||||||
|
self.receiver.send(data)
|
||||||
|
|
||||||
|
def connection_made(self, transport: 'asyncio.transports.BaseTransport') -> None:
|
||||||
|
self.transport = typing.cast('asyncio.transports.Transport', transport)
|
||||||
|
|
||||||
|
def connection_lost(self, exc: typing.Optional[Exception]) -> None:
|
||||||
|
# Ensure close other side if not server_side
|
||||||
|
try:
|
||||||
|
self.receiver.close_connection()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def send(self, data: bytes):
|
||||||
|
self.transport.write(data)
|
||||||
|
|
||||||
|
def close_connection(self):
|
||||||
|
try:
|
||||||
|
if not self.transport.is_closing():
|
||||||
|
self.transport.close()
|
||||||
|
except Exception:
|
||||||
|
pass # Ignore errors
|
@ -189,8 +189,10 @@ async def tunnel_proc_async(
|
|||||||
logger.debug('Out of loop, stopping tasks: %s, running: %s', tasks, do_stop.is_set())
|
logger.debug('Out of loop, stopping tasks: %s, running: %s', tasks, do_stop.is_set())
|
||||||
|
|
||||||
# If any task is still running, cancel it
|
# If any task is still running, cancel it
|
||||||
for task in tasks:
|
asyncio.gather(*tasks, return_exceptions=True).cancel()
|
||||||
task.cancel()
|
|
||||||
|
# for task in tasks:
|
||||||
|
# task.cancel()
|
||||||
|
|
||||||
# Wait for all tasks to finish
|
# Wait for all tasks to finish
|
||||||
await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)
|
await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user