mirror of
https://github.com/dkmstr/openuds.git
synced 2024-12-23 17:34:17 +03:00
updates from 4.0 backported
This commit is contained in:
parent
b7962a24f1
commit
adeb6b2a46
@ -126,9 +126,7 @@ async def main():
|
|||||||
data = client.recv(4)
|
data = client.recv(4)
|
||||||
print(data)
|
print(data)
|
||||||
# Upgrade connection to SSL, and use asyncio to handle the rest
|
# Upgrade connection to SSL, and use asyncio to handle the rest
|
||||||
transport: 'asyncio.transports.Transport'
|
(_, protocol) = await loop.connect_accepted_socket( # type: ignore
|
||||||
protocol: TunnelProtocol
|
|
||||||
(transport, protocol) = await loop.connect_accepted_socket( # type: ignore
|
|
||||||
lambda: TunnelProtocol(), client, ssl=context
|
lambda: TunnelProtocol(), client, ssl=context
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -43,18 +43,21 @@ class ConfigurationType(typing.NamedTuple):
|
|||||||
pidfile: str
|
pidfile: str
|
||||||
user: str
|
user: str
|
||||||
|
|
||||||
log_level: str
|
loglevel: str
|
||||||
log_file: str
|
logfile: str
|
||||||
log_size: int
|
logsize: int
|
||||||
log_number: int
|
lognumber: int
|
||||||
|
|
||||||
listen_address: str
|
listen_address: str
|
||||||
listen_port: int
|
listen_port: int
|
||||||
|
|
||||||
|
ipv6: bool
|
||||||
|
|
||||||
workers: int
|
workers: int
|
||||||
|
|
||||||
ssl_certificate: str
|
ssl_certificate: str
|
||||||
ssl_certificate_key: str
|
ssl_certificate_key: str
|
||||||
|
ssl_password: str
|
||||||
ssl_ciphers: str
|
ssl_ciphers: str
|
||||||
ssl_dhparam: str
|
ssl_dhparam: str
|
||||||
|
|
||||||
@ -111,15 +114,17 @@ def read(
|
|||||||
return ConfigurationType(
|
return ConfigurationType(
|
||||||
pidfile=uds.get('pidfile', ''),
|
pidfile=uds.get('pidfile', ''),
|
||||||
user=uds.get('user', ''),
|
user=uds.get('user', ''),
|
||||||
log_level=uds.get('loglevel', 'ERROR'),
|
loglevel=uds.get('loglevel', 'ERROR'),
|
||||||
log_file=uds.get('logfile', ''),
|
logfile=uds.get('logfile', ''),
|
||||||
log_size=int(logsize) * 1024 * 1024,
|
logsize=int(logsize) * 1024 * 1024,
|
||||||
log_number=int(uds.get('lognumber', '3')),
|
lognumber=int(uds.get('lognumber', '3')),
|
||||||
listen_address=uds.get('address', '0.0.0.0'),
|
listen_address=uds.get('address', '0.0.0.0'),
|
||||||
listen_port=int(uds.get('port', '443')),
|
listen_port=int(uds.get('port', '443')),
|
||||||
|
ipv6=uds.get('ipv6', 'false').lower() == 'true',
|
||||||
workers=int(uds.get('workers', '0')) or multiprocessing.cpu_count(),
|
workers=int(uds.get('workers', '0')) or multiprocessing.cpu_count(),
|
||||||
ssl_certificate=uds['ssl_certificate'],
|
ssl_certificate=uds['ssl_certificate'],
|
||||||
ssl_certificate_key=uds['ssl_certificate_key'],
|
ssl_certificate_key=uds.get('ssl_certificate_key', ''),
|
||||||
|
ssl_password=uds.get('ssl_password', ''),
|
||||||
ssl_ciphers=uds.get('ssl_ciphers'),
|
ssl_ciphers=uds.get('ssl_ciphers'),
|
||||||
ssl_dhparam=uds.get('ssl_dhparam'),
|
ssl_dhparam=uds.get('ssl_dhparam'),
|
||||||
uds_server=uds_server,
|
uds_server=uds_server,
|
||||||
|
@ -28,34 +28,48 @@
|
|||||||
'''
|
'''
|
||||||
Author: Adolfo Gómez, dkmaster at dkmon dot com
|
Author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||||
'''
|
'''
|
||||||
|
import typing
|
||||||
|
|
||||||
DEBUG = True
|
DEBUG = True
|
||||||
|
|
||||||
if DEBUG:
|
CONFIGFILE: typing.Final[str] = '/etc/udstunnel.conf' if not DEBUG else 'udstunnel.conf'
|
||||||
CONFIGFILE = 'udstunnel.conf'
|
LOGFORMAT: typing.Final[str] = (
|
||||||
LOGFORMAT = '%(levelname)s %(asctime)s %(message)s'
|
'%(levelname)s %(asctime)s %(message)s'
|
||||||
else:
|
if not DEBUG
|
||||||
CONFIGFILE = '/etc/udstunnel.conf'
|
else '%(levelname)s %(asctime)s %(message)s'
|
||||||
LOGFORMAT = '%(levelname)s %(asctime)s %(message)s'
|
)
|
||||||
|
|
||||||
# MAX Length of read buffer for proxyed requests
|
# MAX Length of read buffer for proxyed requests
|
||||||
BUFFER_SIZE = 1024 * 16
|
BUFFER_SIZE: typing.Final[int] = 1024 * 16
|
||||||
# Handshake for conversation start
|
# Handshake for conversation start
|
||||||
HANDSHAKE_V1 = b'\x5AMGB\xA5\x01\x00'
|
HANDSHAKE_V1: typing.Final[bytes] = b'\x5AMGB\xA5\x01\x00'
|
||||||
# Ticket length
|
# Ticket length
|
||||||
TICKET_LENGTH = 48
|
TICKET_LENGTH: typing.Final[int] = 48
|
||||||
# Max Admin password length (stats basically right now)
|
# Max Admin password length (stats basically right now)
|
||||||
PASSWORD_LENGTH = 64
|
PASSWORD_LENGTH: typing.Final[int] = 64
|
||||||
# Bandwidth calc time lapse
|
# Bandwidth calc time lapse
|
||||||
BANDWIDTH_TIME = 10
|
BANDWIDTH_TIME: typing.Final[int] = 10
|
||||||
|
|
||||||
# Commands LENGTH (all same length)
|
# Commands LENGTH (all same length)
|
||||||
COMMAND_LENGTH = 4
|
COMMAND_LENGTH: typing.Final[int] = 4
|
||||||
|
|
||||||
VERSION = 'v2.0.0'
|
VERSION: typing.Final[str] = 'v2.0.0'
|
||||||
|
|
||||||
# Valid commands
|
# Valid commands
|
||||||
COMMAND_OPEN = b'OPEN'
|
COMMAND_OPEN: typing.Final[bytes] = b'OPEN'
|
||||||
COMMAND_TEST = b'TEST'
|
COMMAND_TEST: typing.Final[bytes] = b'TEST'
|
||||||
COMMAND_STAT = b'STAT' # full stats
|
COMMAND_STAT: typing.Final[bytes] = b'STAT' # full stats
|
||||||
COMMAND_INFO = b'INFO' # Basic stats, currently same as FULL
|
COMMAND_INFO: typing.Final[bytes] = b'INFO' # Basic stats, currently same as FULL
|
||||||
|
|
||||||
|
RESPONSE_ERROR_TICKET: typing.Final[bytes] = b'ERROR_TICKET'
|
||||||
|
RESPONSE_ERROR_COMMAND: typing.Final[bytes] = b'ERROR_COMMAND'
|
||||||
|
RESPONSE_ERROR_TIMEOUT: typing.Final[bytes] = b'TIMEOUT'
|
||||||
|
RESPONSE_FORBIDDEN: typing.Final[bytes] = b'FORBIDDEN'
|
||||||
|
|
||||||
|
RESPONSE_OK: typing.Final[bytes] = b'OK'
|
||||||
|
|
||||||
|
# Timeout for command
|
||||||
|
TIMEOUT_COMMAND: typing.Final[int] = 3
|
||||||
|
|
||||||
|
# Backlog for listen socket
|
||||||
|
BACKLOG = 1024
|
||||||
|
@ -156,13 +156,16 @@ class Processes:
|
|||||||
ns: 'Namespace',
|
ns: 'Namespace',
|
||||||
) -> None:
|
) -> None:
|
||||||
if cfg.use_uvloop:
|
if cfg.use_uvloop:
|
||||||
import uvloop
|
try:
|
||||||
|
import uvloop
|
||||||
|
|
||||||
if sys.version_info >= (3, 11):
|
if sys.version_info >= (3, 11):
|
||||||
with asyncio.Runner(loop_factory=uvloop.new_event_loop) as runner:
|
with asyncio.Runner(loop_factory=uvloop.new_event_loop) as runner:
|
||||||
runner.run(proc(conn, cfg, ns))
|
runner.run(proc(conn, cfg, ns))
|
||||||
else:
|
else:
|
||||||
uvloop.install()
|
uvloop.install()
|
||||||
asyncio.run(proc(conn, cfg, ns))
|
asyncio.run(proc(conn, cfg, ns))
|
||||||
|
except ImportError:
|
||||||
|
logger.warning('uvloop not found, using default asyncio')
|
||||||
else:
|
else:
|
||||||
asyncio.run(proc(conn, cfg, ns))
|
asyncio.run(proc(conn, cfg, ns))
|
||||||
|
@ -46,6 +46,7 @@ logger = logging.getLogger(__name__)
|
|||||||
class Proxy:
|
class Proxy:
|
||||||
cfg: 'config.ConfigurationType'
|
cfg: 'config.ConfigurationType'
|
||||||
ns: 'Namespace'
|
ns: 'Namespace'
|
||||||
|
finished: asyncio.Event
|
||||||
|
|
||||||
def __init__(self, cfg: 'config.ConfigurationType', ns: 'Namespace') -> None:
|
def __init__(self, cfg: 'config.ConfigurationType', ns: 'Namespace') -> None:
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
@ -63,22 +64,27 @@ class Proxy:
|
|||||||
addr = source.getpeername()
|
addr = source.getpeername()
|
||||||
except Exception:
|
except Exception:
|
||||||
addr = 'Unknown'
|
addr = 'Unknown'
|
||||||
logger.error('Proxy error from %s: %s', addr, e)
|
logger.exception('Proxy error from %s: %s (%s--%s)', addr, e, source, context)
|
||||||
|
|
||||||
async def proxy(self, source: socket.socket, context: 'ssl.SSLContext') -> None:
|
async def proxy(self, source: socket.socket, context: 'ssl.SSLContext') -> None:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_running_loop()
|
||||||
# Handshake correct in this point, upgrade the connection to TSL and let
|
# Handshake correct in this point, upgrade the connection to TSL and let
|
||||||
# the protocol controller do the rest
|
# the protocol controller do the rest
|
||||||
|
self.finished = asyncio.Event()
|
||||||
|
|
||||||
# Upgrade connection to SSL, and use asyncio to handle the rest
|
# Upgrade connection to SSL, and use asyncio to handle the rest
|
||||||
try:
|
try:
|
||||||
protocol: tunnel.TunnelProtocol
|
def factory() -> tunnel.TunnelProtocol:
|
||||||
# (connect accepted loop not present on AbastractEventLoop definition < 3.10)
|
return tunnel.TunnelProtocol(self)
|
||||||
(_, protocol) = await loop.connect_accepted_socket( # type: ignore
|
# (connect accepted loop not present on AbastractEventLoop definition < 3.10), that's why we use ignore
|
||||||
lambda: tunnel.TunnelProtocol(self), source, ssl=context
|
await loop.connect_accepted_socket( # type: ignore
|
||||||
|
factory, source, ssl=context
|
||||||
)
|
)
|
||||||
|
|
||||||
await protocol.finished
|
# Wait for connection to be closed
|
||||||
|
await self.finished.wait()
|
||||||
|
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass # Return on cancel
|
pass # Return on cancel
|
||||||
|
|
||||||
|
@ -31,6 +31,7 @@ Author: Adolfo Gómez, dkmaster at dkmon dot com
|
|||||||
import asyncio
|
import asyncio
|
||||||
import typing
|
import typing
|
||||||
import logging
|
import logging
|
||||||
|
import socket
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
@ -46,8 +47,6 @@ if typing.TYPE_CHECKING:
|
|||||||
|
|
||||||
# Protocol
|
# Protocol
|
||||||
class TunnelProtocol(asyncio.Protocol):
|
class TunnelProtocol(asyncio.Protocol):
|
||||||
# future to mark eof
|
|
||||||
finished: asyncio.Future
|
|
||||||
# Transport and other side of tunnel
|
# Transport and other side of tunnel
|
||||||
transport: 'asyncio.transports.Transport'
|
transport: 'asyncio.transports.Transport'
|
||||||
other_side: 'TunnelProtocol'
|
other_side: 'TunnelProtocol'
|
||||||
@ -56,7 +55,7 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
# Command buffer
|
# Command buffer
|
||||||
cmd: bytes
|
cmd: bytes
|
||||||
# Ticket
|
# Ticket
|
||||||
notify_ticket: bytes
|
notify_ticket: bytes # Only exists on "slave" transport (that is, tunnel from us to remote machine)
|
||||||
# owner Proxy class
|
# owner Proxy class
|
||||||
owner: 'proxy.Proxy'
|
owner: 'proxy.Proxy'
|
||||||
# source of connection
|
# source of connection
|
||||||
@ -66,6 +65,8 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
stats_manager: stats.Stats
|
stats_manager: stats.Stats
|
||||||
# counter
|
# counter
|
||||||
counter: stats.StatsSingleCounter
|
counter: stats.StatsSingleCounter
|
||||||
|
# If there is a timeout task running
|
||||||
|
timeout_task: typing.Optional[asyncio.Task] = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, owner: 'proxy.Proxy', other_side: typing.Optional['TunnelProtocol'] = None
|
self, owner: 'proxy.Proxy', other_side: typing.Optional['TunnelProtocol'] = None
|
||||||
@ -82,19 +83,23 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
self.stats_manager = stats.Stats(owner.ns)
|
self.stats_manager = stats.Stats(owner.ns)
|
||||||
self.counter = self.stats_manager.as_sent_counter()
|
self.counter = self.stats_manager.as_sent_counter()
|
||||||
self.runner = self.do_command
|
self.runner = self.do_command
|
||||||
|
# Set starting timeout task, se we dont get hunged on connections without data
|
||||||
|
self.set_timeout(consts.TIMEOUT_COMMAND)
|
||||||
|
|
||||||
# transport is undefined until connection_made is called
|
# transport is undefined until connection_made is called
|
||||||
self.finished = asyncio.Future()
|
|
||||||
self.cmd = b''
|
self.cmd = b''
|
||||||
self.notify_ticket = b''
|
self.notify_ticket = b''
|
||||||
self.owner = owner
|
self.owner = owner
|
||||||
self.source = ('', 0)
|
self.source = ('', 0)
|
||||||
self.destination = ('', 0)
|
self.destination = ('', 0)
|
||||||
|
|
||||||
|
|
||||||
def process_open(self) -> None:
|
def process_open(self) -> None:
|
||||||
# Open Command has the ticket behind it
|
# Open Command has the ticket behind it
|
||||||
|
|
||||||
if len(self.cmd) < consts.TICKET_LENGTH + consts.COMMAND_LENGTH:
|
if len(self.cmd) < consts.TICKET_LENGTH + consts.COMMAND_LENGTH:
|
||||||
|
# Reactivate timeout, will be deactivated on do_command
|
||||||
|
self.set_timeout(consts.TIMEOUT_COMMAND)
|
||||||
return # Wait for more data to complete OPEN command
|
return # Wait for more data to complete OPEN command
|
||||||
|
|
||||||
# Ticket received, now process it with UDS
|
# Ticket received, now process it with UDS
|
||||||
@ -106,7 +111,7 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
# clean up the command
|
# clean up the command
|
||||||
self.cmd = b''
|
self.cmd = b''
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
async def open_other_side() -> None:
|
async def open_other_side() -> None:
|
||||||
try:
|
try:
|
||||||
@ -115,7 +120,7 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error('ERROR %s', e.args[0] if e.args else e)
|
logger.error('ERROR %s', e.args[0] if e.args else e)
|
||||||
self.transport.write(b'ERROR_TICKET')
|
self.transport.write(consts.RESPONSE_ERROR_TICKET)
|
||||||
self.transport.close() # And force close
|
self.transport.close() # And force close
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -130,10 +135,12 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
family = socket.AF_INET6 if ':' in self.destination[0] or self.owner.cfg.ipv6 else socket.AF_INET
|
||||||
(_, protocol) = await loop.create_connection(
|
(_, protocol) = await loop.create_connection(
|
||||||
lambda: TunnelProtocol(self.owner, self),
|
lambda: TunnelProtocol(self.owner, self),
|
||||||
self.destination[0],
|
self.destination[0],
|
||||||
self.destination[1],
|
self.destination[1],
|
||||||
|
family=family,
|
||||||
)
|
)
|
||||||
self.other_side = typing.cast('TunnelProtocol', protocol)
|
self.other_side = typing.cast('TunnelProtocol', protocol)
|
||||||
|
|
||||||
@ -145,6 +152,7 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
logger.error('Error opening connection: %s', e)
|
logger.error('Error opening connection: %s', e)
|
||||||
self.close_connection()
|
self.close_connection()
|
||||||
|
|
||||||
|
# add open other side to the loop
|
||||||
loop.create_task(open_other_side())
|
loop.create_task(open_other_side())
|
||||||
# From now, proxy connection
|
# From now, proxy connection
|
||||||
self.runner = self.do_proxy
|
self.runner = self.do_proxy
|
||||||
@ -160,7 +168,7 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
# Check valid source ip
|
# Check valid source ip
|
||||||
if self.transport.get_extra_info('peername')[0] not in self.owner.cfg.allow:
|
if self.transport.get_extra_info('peername')[0] not in self.owner.cfg.allow:
|
||||||
# Invalid source
|
# Invalid source
|
||||||
self.transport.write(b'FORBIDDEN')
|
self.transport.write(consts.RESPONSE_FORBIDDEN)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Check password
|
# Check password
|
||||||
@ -171,7 +179,7 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
|
|
||||||
if passwd.decode(errors='ignore') != self.owner.cfg.secret:
|
if passwd.decode(errors='ignore') != self.owner.cfg.secret:
|
||||||
# Invalid password
|
# Invalid password
|
||||||
self.transport.write(b'FORBIDDEN')
|
self.transport.write(consts.RESPONSE_FORBIDDEN)
|
||||||
return
|
return
|
||||||
|
|
||||||
data = stats.GlobalStats.get_stats(self.owner.ns)
|
data = stats.GlobalStats.get_stats(self.owner.ns)
|
||||||
@ -184,18 +192,50 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
finally:
|
finally:
|
||||||
self.close_connection()
|
self.close_connection()
|
||||||
|
|
||||||
|
async def timeout(self, wait: int) -> None:
|
||||||
|
""" Timeout can only occur while waiting for a command."""
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(wait)
|
||||||
|
logger.error('TIMEOUT FROM %s', self.pretty_source())
|
||||||
|
self.transport.write(consts.RESPONSE_ERROR_TIMEOUT)
|
||||||
|
self.close_connection()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def set_timeout(self, wait: int) -> None:
|
||||||
|
"""Set a timeout for this connection.
|
||||||
|
If reached, the connection will be closed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
wait (int): Timeout in seconds
|
||||||
|
|
||||||
|
"""
|
||||||
|
if self.timeout_task:
|
||||||
|
self.timeout_task.cancel()
|
||||||
|
self.timeout_task = asyncio.create_task(self.timeout(wait))
|
||||||
|
|
||||||
|
def clean_timeout(self) -> None:
|
||||||
|
"""Clean the timeout task if any.
|
||||||
|
"""
|
||||||
|
if self.timeout_task:
|
||||||
|
self.timeout_task.cancel()
|
||||||
|
self.timeout_task = None
|
||||||
|
|
||||||
def do_command(self, data: bytes) -> None:
|
def do_command(self, data: bytes) -> None:
|
||||||
self.cmd += data
|
if self.cmd == b'':
|
||||||
if len(self.cmd) >= consts.COMMAND_LENGTH:
|
|
||||||
logger.info('CONNECT FROM %s', self.pretty_source())
|
logger.info('CONNECT FROM %s', self.pretty_source())
|
||||||
|
|
||||||
|
self.clean_timeout()
|
||||||
|
self.cmd += data
|
||||||
|
# Ensure we don't get a timeout
|
||||||
|
if len(self.cmd) >= consts.COMMAND_LENGTH:
|
||||||
command = self.cmd[: consts.COMMAND_LENGTH]
|
command = self.cmd[: consts.COMMAND_LENGTH]
|
||||||
try:
|
try:
|
||||||
if command == consts.COMMAND_OPEN:
|
if command == consts.COMMAND_OPEN:
|
||||||
self.process_open()
|
self.process_open()
|
||||||
elif command == consts.COMMAND_TEST:
|
elif command == consts.COMMAND_TEST:
|
||||||
logger.info('COMMAND: TEST')
|
logger.info('COMMAND: TEST')
|
||||||
self.transport.write(b'OK')
|
self.transport.write(consts.RESPONSE_OK)
|
||||||
self.close_connection()
|
self.close_connection()
|
||||||
return
|
return
|
||||||
elif command in (consts.COMMAND_STAT, consts.COMMAND_INFO):
|
elif command in (consts.COMMAND_STAT, consts.COMMAND_INFO):
|
||||||
@ -206,9 +246,11 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
raise Exception('Invalid command')
|
raise Exception('Invalid command')
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.error('ERROR from %s', self.pretty_source())
|
logger.error('ERROR from %s', self.pretty_source())
|
||||||
self.transport.write(b'ERROR_COMMAND')
|
self.transport.write(consts.RESPONSE_ERROR_COMMAND)
|
||||||
self.close_connection()
|
self.close_connection()
|
||||||
return
|
return
|
||||||
|
else:
|
||||||
|
self.set_timeout(consts.TIMEOUT_COMMAND)
|
||||||
|
|
||||||
# if not enough data to process command, wait for more
|
# if not enough data to process command, wait for more
|
||||||
|
|
||||||
@ -221,6 +263,7 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
|
|
||||||
def connection_made(self, transport: 'asyncio.transports.BaseTransport') -> None:
|
def connection_made(self, transport: 'asyncio.transports.BaseTransport') -> None:
|
||||||
logger.debug('Connection made: %s', transport.get_extra_info('peername'))
|
logger.debug('Connection made: %s', transport.get_extra_info('peername'))
|
||||||
|
self.main = True # This is the main connection
|
||||||
|
|
||||||
# We know for sure that the transport is a Transport.
|
# We know for sure that the transport is a Transport.
|
||||||
self.transport = typing.cast('asyncio.transports.Transport', transport)
|
self.transport = typing.cast('asyncio.transports.Transport', transport)
|
||||||
@ -239,10 +282,12 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.notify_ticket = b'' # Clean up so no more notifications
|
self.notify_ticket = b'' # Clean up so no more notifications
|
||||||
|
else: # No ticket, this is "main" connection (from client to us). Notify owner that we are done
|
||||||
|
self.owner.finished.set()
|
||||||
|
|
||||||
def connection_lost(self, exc: typing.Optional[Exception]) -> None:
|
def connection_lost(self, exc: typing.Optional[Exception]) -> None:
|
||||||
logger.debug('Connection closed : %s', exc)
|
logger.debug('Connection closed : %s', exc)
|
||||||
self.finished.set_result(True)
|
# Ensure close other side if any
|
||||||
if self.other_side is not self:
|
if self.other_side is not self:
|
||||||
self.other_side.transport.close()
|
self.other_side.transport.close()
|
||||||
else:
|
else:
|
||||||
@ -250,12 +295,17 @@ class TunnelProtocol(asyncio.Protocol):
|
|||||||
self.notifyEnd()
|
self.notifyEnd()
|
||||||
|
|
||||||
# helpers
|
# helpers
|
||||||
|
@staticmethod
|
||||||
|
def pretty_address(address: typing.Tuple[str, int]) -> str:
|
||||||
|
if ':' in address[0]:
|
||||||
|
return '[' + address[0] + ']:' + str(address[1])
|
||||||
|
return address[0] + ':' + str(address[1])
|
||||||
# source address, pretty format
|
# source address, pretty format
|
||||||
def pretty_source(self) -> str:
|
def pretty_source(self) -> str:
|
||||||
return self.source[0] + ':' + str(self.source[1])
|
return TunnelProtocol.pretty_address(self.source)
|
||||||
|
|
||||||
def pretty_destination(self) -> str:
|
def pretty_destination(self) -> str:
|
||||||
return self.destination[0] + ':' + str(self.destination[1])
|
return TunnelProtocol.pretty_address(self.destination)
|
||||||
|
|
||||||
def close_connection(self):
|
def close_connection(self):
|
||||||
self.transport.close()
|
self.transport.close()
|
||||||
|
@ -20,14 +20,21 @@ lognumber = 3
|
|||||||
# Listen address. Defaults to 0.0.0.0
|
# Listen address. Defaults to 0.0.0.0
|
||||||
address = 0.0.0.0
|
address = 0.0.0.0
|
||||||
|
|
||||||
# Number of workers. Defaults to 0 (means "as much as cores")
|
|
||||||
workers = 2
|
|
||||||
|
|
||||||
# Listening port
|
# Listening port
|
||||||
port = 7777
|
port = 7777
|
||||||
|
|
||||||
|
# If force ipv6, defaults to false
|
||||||
|
# Note: if listen address is an ipv6 address, this will be forced to true
|
||||||
|
# This will force dns resolution to ipv6
|
||||||
|
ipv6 = false
|
||||||
|
|
||||||
|
# Number of workers. Defaults to 0 (means "as much as cores")
|
||||||
|
workers = 2
|
||||||
|
|
||||||
|
|
||||||
# SSL Related parameters.
|
# SSL Related parameters.
|
||||||
ssl_certificate = /etc/certs/server.pem
|
ssl_certificate = /etc/certs/server.pem
|
||||||
|
# Key can be included on certificate file, so this is optional
|
||||||
ssl_certificate_key = /etc/certs/key.pem
|
ssl_certificate_key = /etc/certs/key.pem
|
||||||
# ssl_ciphers and ssl_dhparam are optional.
|
# ssl_ciphers and ssl_dhparam are optional.
|
||||||
ssl_ciphers = ECDHE-RSA-AES256-GCM-SHA512:DHE-RSA-AES256-GCM-SHA512:ECDHE-RSA-AES256-GCM-SHA384:DHE-RSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-SHA384
|
ssl_ciphers = ECDHE-RSA-AES256-GCM-SHA512:DHE-RSA-AES256-GCM-SHA512:ECDHE-RSA-AES256-GCM-SHA384:DHE-RSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-SHA384
|
||||||
@ -40,6 +47,11 @@ ssl_dhparam = /etc/certs/dhparam.pem
|
|||||||
# https://www.example.com:14333/uds/rest/tunnel/ticket
|
# https://www.example.com:14333/uds/rest/tunnel/ticket
|
||||||
uds_server = http://172.27.0.1:8000/uds/rest/tunnel/ticket
|
uds_server = http://172.27.0.1:8000/uds/rest/tunnel/ticket
|
||||||
uds_token = eBCeFxTBw1IKXCqq-RlncshwWIfrrqxc8y5nehqiqMtRztwD
|
uds_token = eBCeFxTBw1IKXCqq-RlncshwWIfrrqxc8y5nehqiqMtRztwD
|
||||||
|
# Defaults to 10 seconds
|
||||||
|
# uds_timeout = 10
|
||||||
|
|
||||||
|
# If verify ssl certificate on uds server. Defaults to true
|
||||||
|
# uds_verify_ssl = true
|
||||||
|
|
||||||
# Secret to get access to admin commands (Currently only stats commands). No default for this.
|
# Secret to get access to admin commands (Currently only stats commands). No default for this.
|
||||||
# Admin commands and only allowed from "allow" ips
|
# Admin commands and only allowed from "allow" ips
|
||||||
@ -50,3 +62,7 @@ secret = MySecret
|
|||||||
# Only use IPs, no networks allowed
|
# Only use IPs, no networks allowed
|
||||||
# defaults to localhost (change if listen address is different from 0.0.0.0)
|
# defaults to localhost (change if listen address is different from 0.0.0.0)
|
||||||
allow = 127.0.0.1
|
allow = 127.0.0.1
|
||||||
|
|
||||||
|
|
||||||
|
# If use uvloop as event loop. Defaults to true
|
||||||
|
# use_uvloop = true
|
@ -39,6 +39,8 @@ import ssl
|
|||||||
import socket
|
import socket
|
||||||
import logging
|
import logging
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
# event for stop notification
|
||||||
|
import threading
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -59,16 +61,15 @@ if typing.TYPE_CHECKING:
|
|||||||
from multiprocessing.connection import Connection
|
from multiprocessing.connection import Connection
|
||||||
from multiprocessing.managers import Namespace
|
from multiprocessing.managers import Namespace
|
||||||
|
|
||||||
BACKLOG = 1024
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
do_stop = False
|
running: threading.Event = threading.Event()
|
||||||
|
|
||||||
|
|
||||||
def stop_signal(signum: int, frame: typing.Any) -> None:
|
def stop_signal(signum: int, frame: typing.Any) -> None:
|
||||||
global do_stop
|
global running
|
||||||
do_stop = True
|
running.clear()
|
||||||
logger.debug('SIGNAL %s, frame: %s', signum, frame)
|
logger.debug('SIGNAL %s, frame: %s', signum, frame)
|
||||||
|
|
||||||
|
|
||||||
@ -76,26 +77,26 @@ def setup_log(cfg: config.ConfigurationType) -> None:
|
|||||||
from logging.handlers import RotatingFileHandler
|
from logging.handlers import RotatingFileHandler
|
||||||
|
|
||||||
# Update logging if needed
|
# Update logging if needed
|
||||||
if cfg.log_file:
|
if cfg.logfile:
|
||||||
fileh = RotatingFileHandler(
|
fileh = RotatingFileHandler(
|
||||||
filename=cfg.log_file,
|
filename=cfg.logfile,
|
||||||
mode='a',
|
mode='a',
|
||||||
maxBytes=cfg.log_size,
|
maxBytes=cfg.logsize,
|
||||||
backupCount=cfg.log_number,
|
backupCount=cfg.lognumber,
|
||||||
)
|
)
|
||||||
formatter = logging.Formatter(consts.LOGFORMAT)
|
formatter = logging.Formatter(consts.LOGFORMAT)
|
||||||
fileh.setFormatter(formatter)
|
fileh.setFormatter(formatter)
|
||||||
log = logging.getLogger()
|
log = logging.getLogger()
|
||||||
log.setLevel(cfg.log_level)
|
log.setLevel(cfg.loglevel)
|
||||||
# for hdlr in log.handlers[:]:
|
# for hdlr in log.handlers[:]:
|
||||||
# log.removeHandler(hdlr)
|
# log.removeHandler(hdlr)
|
||||||
log.addHandler(fileh)
|
log.addHandler(fileh)
|
||||||
else:
|
else:
|
||||||
# Setup basic logging
|
# Setup basic logging
|
||||||
log = logging.getLogger()
|
log = logging.getLogger()
|
||||||
log.setLevel(cfg.log_level)
|
log.setLevel(cfg.loglevel)
|
||||||
handler = logging.StreamHandler(sys.stderr)
|
handler = logging.StreamHandler(sys.stderr)
|
||||||
handler.setLevel(cfg.log_level)
|
handler.setLevel(cfg.loglevel)
|
||||||
formatter = logging.Formatter(
|
formatter = logging.Formatter(
|
||||||
'%(levelname)s - %(message)s'
|
'%(levelname)s - %(message)s'
|
||||||
) # Basic log format, nice for syslog
|
) # Basic log format, nice for syslog
|
||||||
@ -107,10 +108,7 @@ async def tunnel_proc_async(
|
|||||||
pipe: 'Connection', cfg: config.ConfigurationType, ns: 'Namespace'
|
pipe: 'Connection', cfg: config.ConfigurationType, ns: 'Namespace'
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
||||||
try:
|
loop = asyncio.get_running_loop()
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
except RuntimeError: # older python versions
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
|
|
||||||
tasks: typing.List[asyncio.Task] = []
|
tasks: typing.List[asyncio.Task] = []
|
||||||
|
|
||||||
@ -123,8 +121,13 @@ async def tunnel_proc_async(
|
|||||||
] = pipe.recv()
|
] = pipe.recv()
|
||||||
if msg:
|
if msg:
|
||||||
return msg
|
return msg
|
||||||
|
except EOFError:
|
||||||
|
logger.debug('Parent process closed connection')
|
||||||
|
pipe.close()
|
||||||
|
return None, None
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception('Receiving data from parent process')
|
logger.exception('Receiving data from parent process')
|
||||||
|
pipe.close()
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
async def run_server() -> None:
|
async def run_server() -> None:
|
||||||
@ -133,7 +136,15 @@ async def tunnel_proc_async(
|
|||||||
|
|
||||||
# Generate SSL context
|
# Generate SSL context
|
||||||
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||||
context.load_cert_chain(cfg.ssl_certificate, cfg.ssl_certificate_key)
|
args: typing.Dict[str, typing.Any] = {
|
||||||
|
'certfile': cfg.ssl_certificate,
|
||||||
|
}
|
||||||
|
if cfg.ssl_certificate_key:
|
||||||
|
args['keyfile'] = cfg.ssl_certificate_key
|
||||||
|
if cfg.ssl_password:
|
||||||
|
args['password'] = cfg.ssl_password
|
||||||
|
|
||||||
|
context.load_cert_chain(**args)
|
||||||
|
|
||||||
if cfg.ssl_ciphers:
|
if cfg.ssl_ciphers:
|
||||||
context.set_ciphers(cfg.ssl_ciphers)
|
context.set_ciphers(cfg.ssl_ciphers)
|
||||||
@ -141,29 +152,44 @@ async def tunnel_proc_async(
|
|||||||
if cfg.ssl_dhparam:
|
if cfg.ssl_dhparam:
|
||||||
context.load_dh_params(cfg.ssl_dhparam)
|
context.load_dh_params(cfg.ssl_dhparam)
|
||||||
|
|
||||||
while True:
|
try:
|
||||||
address: typing.Optional[typing.Tuple[str, int]] = ('', 0)
|
while True:
|
||||||
try:
|
address: typing.Optional[typing.Tuple[str, int]] = ('', 0)
|
||||||
(sock, address) = await loop.run_in_executor(None, get_socket)
|
try:
|
||||||
if not sock:
|
(sock, address) = await loop.run_in_executor(None, get_socket)
|
||||||
break # No more sockets, exit
|
if not sock:
|
||||||
logger.debug(f'CONNECTION from {address!r} (pid: {os.getpid()})')
|
break # No more sockets, exit
|
||||||
tasks.append(asyncio.create_task(tunneler(sock, context)))
|
logger.debug(f'CONNECTION from {address!r} (pid: {os.getpid()})')
|
||||||
except Exception:
|
tasks.append(asyncio.create_task(tunneler(sock, context)))
|
||||||
logger.error('NEGOTIATION ERROR from %s', address[0] if address else 'unknown')
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
logger.error('NEGOTIATION ERROR from %s', address[0] if address else 'unknown')
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass # Stop
|
||||||
|
|
||||||
# create task for server
|
# create task for server
|
||||||
tasks.append(asyncio.create_task(run_server()))
|
tasks.append(asyncio.create_task(run_server()))
|
||||||
|
|
||||||
while tasks:
|
try:
|
||||||
to_wait = tasks[:] # Get a copy of the list, and clean the original
|
while tasks and running.is_set():
|
||||||
# Wait for tasks to finish
|
to_wait = tasks[:] # Get a copy of the list, and clean the original
|
||||||
done, _ = await asyncio.wait(to_wait, return_when=asyncio.FIRST_COMPLETED)
|
# Wait for tasks to finish
|
||||||
# Remove finished tasks
|
done, _ = await asyncio.wait(to_wait, return_when=asyncio.FIRST_COMPLETED, timeout=2)
|
||||||
for task in done:
|
# Remove finished tasks
|
||||||
tasks.remove(task)
|
for task in done:
|
||||||
if task.exception():
|
tasks.remove(task)
|
||||||
logger.exception('TUNNEL ERROR')
|
if task.exception():
|
||||||
|
logger.exception('TUNNEL ERROR')
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
running.clear() # ensure we stop
|
||||||
|
|
||||||
|
# If any task is still running, cancel it
|
||||||
|
for task in tasks:
|
||||||
|
task.cancel()
|
||||||
|
|
||||||
|
# Wait for all tasks to finish
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
logger.info('PROCESS %s stopped', os.getpid())
|
logger.info('PROCESS %s stopped', os.getpid())
|
||||||
|
|
||||||
@ -205,11 +231,11 @@ def tunnel_main(args: 'argparse.Namespace') -> None:
|
|||||||
# logger.warning('socket.REUSEPORT not available')
|
# logger.warning('socket.REUSEPORT not available')
|
||||||
try:
|
try:
|
||||||
sock.bind((cfg.listen_address, cfg.listen_port))
|
sock.bind((cfg.listen_address, cfg.listen_port))
|
||||||
sock.listen(BACKLOG)
|
sock.listen(consts.BACKLOG)
|
||||||
|
|
||||||
# If running as root, and requested drop privileges after port bind
|
# If running as root, and requested drop privileges after port bind
|
||||||
if os.getuid() == 0 and cfg.user:
|
if os.getuid() == 0 and cfg.user:
|
||||||
logger.debug('Changing to user %s', cfg.user)
|
logger.debug('Changing to user %s', cfg.user)
|
||||||
pwu = pwd.getpwnam(cfg.user)
|
pwu = pwd.getpwnam(cfg.user)
|
||||||
# os.setgid(pwu.pw_gid)
|
# os.setgid(pwu.pw_gid)
|
||||||
os.setuid(pwu.pw_uid)
|
os.setuid(pwu.pw_uid)
|
||||||
@ -233,16 +259,22 @@ def tunnel_main(args: 'argparse.Namespace') -> None:
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Setup signal handlers
|
# Setup signal handlers
|
||||||
signal.signal(signal.SIGINT, stop_signal)
|
try:
|
||||||
signal.signal(signal.SIGTERM, stop_signal)
|
signal.signal(signal.SIGINT, stop_signal)
|
||||||
|
signal.signal(signal.SIGTERM, stop_signal)
|
||||||
|
except Exception as e:
|
||||||
|
# Signal not available on threads, and we use threads on tests,
|
||||||
|
# so we will ignore this because on tests signals are not important
|
||||||
|
logger.warning('Signal not available: %s', e)
|
||||||
|
|
||||||
stats_collector = stats.GlobalStats()
|
stats_collector = stats.GlobalStats()
|
||||||
|
|
||||||
prcs = processes.Processes(tunnel_proc_async, cfg, stats_collector.ns)
|
prcs = processes.Processes(tunnel_proc_async, cfg, stats_collector.ns)
|
||||||
|
|
||||||
|
running.set() # Signal we are running
|
||||||
with ThreadPoolExecutor(max_workers=256) as executor:
|
with ThreadPoolExecutor(max_workers=256) as executor:
|
||||||
try:
|
try:
|
||||||
while not do_stop:
|
while running.is_set():
|
||||||
try:
|
try:
|
||||||
client, addr = sock.accept()
|
client, addr = sock.accept()
|
||||||
logger.info('CONNECTION from %s', addr)
|
logger.info('CONNECTION from %s', addr)
|
||||||
|
Loading…
Reference in New Issue
Block a user