1
0
mirror of https://github.com/dkmstr/openuds.git synced 2024-12-23 17:34:17 +03:00

updates from 4.0 backported

This commit is contained in:
Adolfo Gómez García 2022-12-19 01:25:29 +01:00
parent b7962a24f1
commit adeb6b2a46
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
8 changed files with 225 additions and 101 deletions

View File

@ -126,9 +126,7 @@ async def main():
data = client.recv(4) data = client.recv(4)
print(data) print(data)
# Upgrade connection to SSL, and use asyncio to handle the rest # Upgrade connection to SSL, and use asyncio to handle the rest
transport: 'asyncio.transports.Transport' (_, protocol) = await loop.connect_accepted_socket( # type: ignore
protocol: TunnelProtocol
(transport, protocol) = await loop.connect_accepted_socket( # type: ignore
lambda: TunnelProtocol(), client, ssl=context lambda: TunnelProtocol(), client, ssl=context
) )

View File

@ -43,18 +43,21 @@ class ConfigurationType(typing.NamedTuple):
pidfile: str pidfile: str
user: str user: str
log_level: str loglevel: str
log_file: str logfile: str
log_size: int logsize: int
log_number: int lognumber: int
listen_address: str listen_address: str
listen_port: int listen_port: int
ipv6: bool
workers: int workers: int
ssl_certificate: str ssl_certificate: str
ssl_certificate_key: str ssl_certificate_key: str
ssl_password: str
ssl_ciphers: str ssl_ciphers: str
ssl_dhparam: str ssl_dhparam: str
@ -111,15 +114,17 @@ def read(
return ConfigurationType( return ConfigurationType(
pidfile=uds.get('pidfile', ''), pidfile=uds.get('pidfile', ''),
user=uds.get('user', ''), user=uds.get('user', ''),
log_level=uds.get('loglevel', 'ERROR'), loglevel=uds.get('loglevel', 'ERROR'),
log_file=uds.get('logfile', ''), logfile=uds.get('logfile', ''),
log_size=int(logsize) * 1024 * 1024, logsize=int(logsize) * 1024 * 1024,
log_number=int(uds.get('lognumber', '3')), lognumber=int(uds.get('lognumber', '3')),
listen_address=uds.get('address', '0.0.0.0'), listen_address=uds.get('address', '0.0.0.0'),
listen_port=int(uds.get('port', '443')), listen_port=int(uds.get('port', '443')),
ipv6=uds.get('ipv6', 'false').lower() == 'true',
workers=int(uds.get('workers', '0')) or multiprocessing.cpu_count(), workers=int(uds.get('workers', '0')) or multiprocessing.cpu_count(),
ssl_certificate=uds['ssl_certificate'], ssl_certificate=uds['ssl_certificate'],
ssl_certificate_key=uds['ssl_certificate_key'], ssl_certificate_key=uds.get('ssl_certificate_key', ''),
ssl_password=uds.get('ssl_password', ''),
ssl_ciphers=uds.get('ssl_ciphers'), ssl_ciphers=uds.get('ssl_ciphers'),
ssl_dhparam=uds.get('ssl_dhparam'), ssl_dhparam=uds.get('ssl_dhparam'),
uds_server=uds_server, uds_server=uds_server,

View File

@ -28,34 +28,48 @@
''' '''
Author: Adolfo Gómez, dkmaster at dkmon dot com Author: Adolfo Gómez, dkmaster at dkmon dot com
''' '''
import typing
DEBUG = True DEBUG = True
if DEBUG: CONFIGFILE: typing.Final[str] = '/etc/udstunnel.conf' if not DEBUG else 'udstunnel.conf'
CONFIGFILE = 'udstunnel.conf' LOGFORMAT: typing.Final[str] = (
LOGFORMAT = '%(levelname)s %(asctime)s %(message)s' '%(levelname)s %(asctime)s %(message)s'
else: if not DEBUG
CONFIGFILE = '/etc/udstunnel.conf' else '%(levelname)s %(asctime)s %(message)s'
LOGFORMAT = '%(levelname)s %(asctime)s %(message)s' )
# MAX Length of read buffer for proxyed requests # MAX Length of read buffer for proxyed requests
BUFFER_SIZE = 1024 * 16 BUFFER_SIZE: typing.Final[int] = 1024 * 16
# Handshake for conversation start # Handshake for conversation start
HANDSHAKE_V1 = b'\x5AMGB\xA5\x01\x00' HANDSHAKE_V1: typing.Final[bytes] = b'\x5AMGB\xA5\x01\x00'
# Ticket length # Ticket length
TICKET_LENGTH = 48 TICKET_LENGTH: typing.Final[int] = 48
# Max Admin password length (stats basically right now) # Max Admin password length (stats basically right now)
PASSWORD_LENGTH = 64 PASSWORD_LENGTH: typing.Final[int] = 64
# Bandwidth calc time lapse # Bandwidth calc time lapse
BANDWIDTH_TIME = 10 BANDWIDTH_TIME: typing.Final[int] = 10
# Commands LENGTH (all same length) # Commands LENGTH (all same length)
COMMAND_LENGTH = 4 COMMAND_LENGTH: typing.Final[int] = 4
VERSION = 'v2.0.0' VERSION: typing.Final[str] = 'v2.0.0'
# Valid commands # Valid commands
COMMAND_OPEN = b'OPEN' COMMAND_OPEN: typing.Final[bytes] = b'OPEN'
COMMAND_TEST = b'TEST' COMMAND_TEST: typing.Final[bytes] = b'TEST'
COMMAND_STAT = b'STAT' # full stats COMMAND_STAT: typing.Final[bytes] = b'STAT' # full stats
COMMAND_INFO = b'INFO' # Basic stats, currently same as FULL COMMAND_INFO: typing.Final[bytes] = b'INFO' # Basic stats, currently same as FULL
RESPONSE_ERROR_TICKET: typing.Final[bytes] = b'ERROR_TICKET'
RESPONSE_ERROR_COMMAND: typing.Final[bytes] = b'ERROR_COMMAND'
RESPONSE_ERROR_TIMEOUT: typing.Final[bytes] = b'TIMEOUT'
RESPONSE_FORBIDDEN: typing.Final[bytes] = b'FORBIDDEN'
RESPONSE_OK: typing.Final[bytes] = b'OK'
# Timeout for command
TIMEOUT_COMMAND: typing.Final[int] = 3
# Backlog for listen socket
BACKLOG = 1024

View File

@ -156,13 +156,16 @@ class Processes:
ns: 'Namespace', ns: 'Namespace',
) -> None: ) -> None:
if cfg.use_uvloop: if cfg.use_uvloop:
import uvloop try:
import uvloop
if sys.version_info >= (3, 11): if sys.version_info >= (3, 11):
with asyncio.Runner(loop_factory=uvloop.new_event_loop) as runner: with asyncio.Runner(loop_factory=uvloop.new_event_loop) as runner:
runner.run(proc(conn, cfg, ns)) runner.run(proc(conn, cfg, ns))
else: else:
uvloop.install() uvloop.install()
asyncio.run(proc(conn, cfg, ns)) asyncio.run(proc(conn, cfg, ns))
except ImportError:
logger.warning('uvloop not found, using default asyncio')
else: else:
asyncio.run(proc(conn, cfg, ns)) asyncio.run(proc(conn, cfg, ns))

View File

@ -46,6 +46,7 @@ logger = logging.getLogger(__name__)
class Proxy: class Proxy:
cfg: 'config.ConfigurationType' cfg: 'config.ConfigurationType'
ns: 'Namespace' ns: 'Namespace'
finished: asyncio.Event
def __init__(self, cfg: 'config.ConfigurationType', ns: 'Namespace') -> None: def __init__(self, cfg: 'config.ConfigurationType', ns: 'Namespace') -> None:
self.cfg = cfg self.cfg = cfg
@ -63,22 +64,27 @@ class Proxy:
addr = source.getpeername() addr = source.getpeername()
except Exception: except Exception:
addr = 'Unknown' addr = 'Unknown'
logger.error('Proxy error from %s: %s', addr, e) logger.exception('Proxy error from %s: %s (%s--%s)', addr, e, source, context)
async def proxy(self, source: socket.socket, context: 'ssl.SSLContext') -> None: async def proxy(self, source: socket.socket, context: 'ssl.SSLContext') -> None:
loop = asyncio.get_event_loop() loop = asyncio.get_running_loop()
# Handshake correct in this point, upgrade the connection to TSL and let # Handshake correct in this point, upgrade the connection to TSL and let
# the protocol controller do the rest # the protocol controller do the rest
self.finished = asyncio.Event()
# Upgrade connection to SSL, and use asyncio to handle the rest # Upgrade connection to SSL, and use asyncio to handle the rest
try: try:
protocol: tunnel.TunnelProtocol def factory() -> tunnel.TunnelProtocol:
# (connect accepted loop not present on AbastractEventLoop definition < 3.10) return tunnel.TunnelProtocol(self)
(_, protocol) = await loop.connect_accepted_socket( # type: ignore # (connect accepted loop not present on AbastractEventLoop definition < 3.10), that's why we use ignore
lambda: tunnel.TunnelProtocol(self), source, ssl=context await loop.connect_accepted_socket( # type: ignore
factory, source, ssl=context
) )
await protocol.finished # Wait for connection to be closed
await self.finished.wait()
except asyncio.CancelledError: except asyncio.CancelledError:
pass # Return on cancel pass # Return on cancel

View File

@ -31,6 +31,7 @@ Author: Adolfo Gómez, dkmaster at dkmon dot com
import asyncio import asyncio
import typing import typing
import logging import logging
import socket
import aiohttp import aiohttp
@ -46,8 +47,6 @@ if typing.TYPE_CHECKING:
# Protocol # Protocol
class TunnelProtocol(asyncio.Protocol): class TunnelProtocol(asyncio.Protocol):
# future to mark eof
finished: asyncio.Future
# Transport and other side of tunnel # Transport and other side of tunnel
transport: 'asyncio.transports.Transport' transport: 'asyncio.transports.Transport'
other_side: 'TunnelProtocol' other_side: 'TunnelProtocol'
@ -56,7 +55,7 @@ class TunnelProtocol(asyncio.Protocol):
# Command buffer # Command buffer
cmd: bytes cmd: bytes
# Ticket # Ticket
notify_ticket: bytes notify_ticket: bytes # Only exists on "slave" transport (that is, tunnel from us to remote machine)
# owner Proxy class # owner Proxy class
owner: 'proxy.Proxy' owner: 'proxy.Proxy'
# source of connection # source of connection
@ -66,6 +65,8 @@ class TunnelProtocol(asyncio.Protocol):
stats_manager: stats.Stats stats_manager: stats.Stats
# counter # counter
counter: stats.StatsSingleCounter counter: stats.StatsSingleCounter
# If there is a timeout task running
timeout_task: typing.Optional[asyncio.Task] = None
def __init__( def __init__(
self, owner: 'proxy.Proxy', other_side: typing.Optional['TunnelProtocol'] = None self, owner: 'proxy.Proxy', other_side: typing.Optional['TunnelProtocol'] = None
@ -82,19 +83,23 @@ class TunnelProtocol(asyncio.Protocol):
self.stats_manager = stats.Stats(owner.ns) self.stats_manager = stats.Stats(owner.ns)
self.counter = self.stats_manager.as_sent_counter() self.counter = self.stats_manager.as_sent_counter()
self.runner = self.do_command self.runner = self.do_command
# Set starting timeout task, se we dont get hunged on connections without data
self.set_timeout(consts.TIMEOUT_COMMAND)
# transport is undefined until connection_made is called # transport is undefined until connection_made is called
self.finished = asyncio.Future()
self.cmd = b'' self.cmd = b''
self.notify_ticket = b'' self.notify_ticket = b''
self.owner = owner self.owner = owner
self.source = ('', 0) self.source = ('', 0)
self.destination = ('', 0) self.destination = ('', 0)
def process_open(self) -> None: def process_open(self) -> None:
# Open Command has the ticket behind it # Open Command has the ticket behind it
if len(self.cmd) < consts.TICKET_LENGTH + consts.COMMAND_LENGTH: if len(self.cmd) < consts.TICKET_LENGTH + consts.COMMAND_LENGTH:
# Reactivate timeout, will be deactivated on do_command
self.set_timeout(consts.TIMEOUT_COMMAND)
return # Wait for more data to complete OPEN command return # Wait for more data to complete OPEN command
# Ticket received, now process it with UDS # Ticket received, now process it with UDS
@ -106,7 +111,7 @@ class TunnelProtocol(asyncio.Protocol):
# clean up the command # clean up the command
self.cmd = b'' self.cmd = b''
loop = asyncio.get_event_loop() loop = asyncio.get_running_loop()
async def open_other_side() -> None: async def open_other_side() -> None:
try: try:
@ -115,7 +120,7 @@ class TunnelProtocol(asyncio.Protocol):
) )
except Exception as e: except Exception as e:
logger.error('ERROR %s', e.args[0] if e.args else e) logger.error('ERROR %s', e.args[0] if e.args else e)
self.transport.write(b'ERROR_TICKET') self.transport.write(consts.RESPONSE_ERROR_TICKET)
self.transport.close() # And force close self.transport.close() # And force close
return return
@ -130,10 +135,12 @@ class TunnelProtocol(asyncio.Protocol):
) )
try: try:
family = socket.AF_INET6 if ':' in self.destination[0] or self.owner.cfg.ipv6 else socket.AF_INET
(_, protocol) = await loop.create_connection( (_, protocol) = await loop.create_connection(
lambda: TunnelProtocol(self.owner, self), lambda: TunnelProtocol(self.owner, self),
self.destination[0], self.destination[0],
self.destination[1], self.destination[1],
family=family,
) )
self.other_side = typing.cast('TunnelProtocol', protocol) self.other_side = typing.cast('TunnelProtocol', protocol)
@ -145,6 +152,7 @@ class TunnelProtocol(asyncio.Protocol):
logger.error('Error opening connection: %s', e) logger.error('Error opening connection: %s', e)
self.close_connection() self.close_connection()
# add open other side to the loop
loop.create_task(open_other_side()) loop.create_task(open_other_side())
# From now, proxy connection # From now, proxy connection
self.runner = self.do_proxy self.runner = self.do_proxy
@ -160,7 +168,7 @@ class TunnelProtocol(asyncio.Protocol):
# Check valid source ip # Check valid source ip
if self.transport.get_extra_info('peername')[0] not in self.owner.cfg.allow: if self.transport.get_extra_info('peername')[0] not in self.owner.cfg.allow:
# Invalid source # Invalid source
self.transport.write(b'FORBIDDEN') self.transport.write(consts.RESPONSE_FORBIDDEN)
return return
# Check password # Check password
@ -171,7 +179,7 @@ class TunnelProtocol(asyncio.Protocol):
if passwd.decode(errors='ignore') != self.owner.cfg.secret: if passwd.decode(errors='ignore') != self.owner.cfg.secret:
# Invalid password # Invalid password
self.transport.write(b'FORBIDDEN') self.transport.write(consts.RESPONSE_FORBIDDEN)
return return
data = stats.GlobalStats.get_stats(self.owner.ns) data = stats.GlobalStats.get_stats(self.owner.ns)
@ -184,18 +192,50 @@ class TunnelProtocol(asyncio.Protocol):
finally: finally:
self.close_connection() self.close_connection()
async def timeout(self, wait: int) -> None:
""" Timeout can only occur while waiting for a command."""
try:
await asyncio.sleep(wait)
logger.error('TIMEOUT FROM %s', self.pretty_source())
self.transport.write(consts.RESPONSE_ERROR_TIMEOUT)
self.close_connection()
except asyncio.CancelledError:
pass
def set_timeout(self, wait: int) -> None:
"""Set a timeout for this connection.
If reached, the connection will be closed.
Args:
wait (int): Timeout in seconds
"""
if self.timeout_task:
self.timeout_task.cancel()
self.timeout_task = asyncio.create_task(self.timeout(wait))
def clean_timeout(self) -> None:
"""Clean the timeout task if any.
"""
if self.timeout_task:
self.timeout_task.cancel()
self.timeout_task = None
def do_command(self, data: bytes) -> None: def do_command(self, data: bytes) -> None:
self.cmd += data if self.cmd == b'':
if len(self.cmd) >= consts.COMMAND_LENGTH:
logger.info('CONNECT FROM %s', self.pretty_source()) logger.info('CONNECT FROM %s', self.pretty_source())
self.clean_timeout()
self.cmd += data
# Ensure we don't get a timeout
if len(self.cmd) >= consts.COMMAND_LENGTH:
command = self.cmd[: consts.COMMAND_LENGTH] command = self.cmd[: consts.COMMAND_LENGTH]
try: try:
if command == consts.COMMAND_OPEN: if command == consts.COMMAND_OPEN:
self.process_open() self.process_open()
elif command == consts.COMMAND_TEST: elif command == consts.COMMAND_TEST:
logger.info('COMMAND: TEST') logger.info('COMMAND: TEST')
self.transport.write(b'OK') self.transport.write(consts.RESPONSE_OK)
self.close_connection() self.close_connection()
return return
elif command in (consts.COMMAND_STAT, consts.COMMAND_INFO): elif command in (consts.COMMAND_STAT, consts.COMMAND_INFO):
@ -206,9 +246,11 @@ class TunnelProtocol(asyncio.Protocol):
raise Exception('Invalid command') raise Exception('Invalid command')
except Exception: except Exception:
logger.error('ERROR from %s', self.pretty_source()) logger.error('ERROR from %s', self.pretty_source())
self.transport.write(b'ERROR_COMMAND') self.transport.write(consts.RESPONSE_ERROR_COMMAND)
self.close_connection() self.close_connection()
return return
else:
self.set_timeout(consts.TIMEOUT_COMMAND)
# if not enough data to process command, wait for more # if not enough data to process command, wait for more
@ -221,6 +263,7 @@ class TunnelProtocol(asyncio.Protocol):
def connection_made(self, transport: 'asyncio.transports.BaseTransport') -> None: def connection_made(self, transport: 'asyncio.transports.BaseTransport') -> None:
logger.debug('Connection made: %s', transport.get_extra_info('peername')) logger.debug('Connection made: %s', transport.get_extra_info('peername'))
self.main = True # This is the main connection
# We know for sure that the transport is a Transport. # We know for sure that the transport is a Transport.
self.transport = typing.cast('asyncio.transports.Transport', transport) self.transport = typing.cast('asyncio.transports.Transport', transport)
@ -239,10 +282,12 @@ class TunnelProtocol(asyncio.Protocol):
) )
) )
self.notify_ticket = b'' # Clean up so no more notifications self.notify_ticket = b'' # Clean up so no more notifications
else: # No ticket, this is "main" connection (from client to us). Notify owner that we are done
self.owner.finished.set()
def connection_lost(self, exc: typing.Optional[Exception]) -> None: def connection_lost(self, exc: typing.Optional[Exception]) -> None:
logger.debug('Connection closed : %s', exc) logger.debug('Connection closed : %s', exc)
self.finished.set_result(True) # Ensure close other side if any
if self.other_side is not self: if self.other_side is not self:
self.other_side.transport.close() self.other_side.transport.close()
else: else:
@ -250,12 +295,17 @@ class TunnelProtocol(asyncio.Protocol):
self.notifyEnd() self.notifyEnd()
# helpers # helpers
@staticmethod
def pretty_address(address: typing.Tuple[str, int]) -> str:
if ':' in address[0]:
return '[' + address[0] + ']:' + str(address[1])
return address[0] + ':' + str(address[1])
# source address, pretty format # source address, pretty format
def pretty_source(self) -> str: def pretty_source(self) -> str:
return self.source[0] + ':' + str(self.source[1]) return TunnelProtocol.pretty_address(self.source)
def pretty_destination(self) -> str: def pretty_destination(self) -> str:
return self.destination[0] + ':' + str(self.destination[1]) return TunnelProtocol.pretty_address(self.destination)
def close_connection(self): def close_connection(self):
self.transport.close() self.transport.close()

View File

@ -20,14 +20,21 @@ lognumber = 3
# Listen address. Defaults to 0.0.0.0 # Listen address. Defaults to 0.0.0.0
address = 0.0.0.0 address = 0.0.0.0
# Number of workers. Defaults to 0 (means "as much as cores")
workers = 2
# Listening port # Listening port
port = 7777 port = 7777
# If force ipv6, defaults to false
# Note: if listen address is an ipv6 address, this will be forced to true
# This will force dns resolution to ipv6
ipv6 = false
# Number of workers. Defaults to 0 (means "as much as cores")
workers = 2
# SSL Related parameters. # SSL Related parameters.
ssl_certificate = /etc/certs/server.pem ssl_certificate = /etc/certs/server.pem
# Key can be included on certificate file, so this is optional
ssl_certificate_key = /etc/certs/key.pem ssl_certificate_key = /etc/certs/key.pem
# ssl_ciphers and ssl_dhparam are optional. # ssl_ciphers and ssl_dhparam are optional.
ssl_ciphers = ECDHE-RSA-AES256-GCM-SHA512:DHE-RSA-AES256-GCM-SHA512:ECDHE-RSA-AES256-GCM-SHA384:DHE-RSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-SHA384 ssl_ciphers = ECDHE-RSA-AES256-GCM-SHA512:DHE-RSA-AES256-GCM-SHA512:ECDHE-RSA-AES256-GCM-SHA384:DHE-RSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-SHA384
@ -40,6 +47,11 @@ ssl_dhparam = /etc/certs/dhparam.pem
# https://www.example.com:14333/uds/rest/tunnel/ticket # https://www.example.com:14333/uds/rest/tunnel/ticket
uds_server = http://172.27.0.1:8000/uds/rest/tunnel/ticket uds_server = http://172.27.0.1:8000/uds/rest/tunnel/ticket
uds_token = eBCeFxTBw1IKXCqq-RlncshwWIfrrqxc8y5nehqiqMtRztwD uds_token = eBCeFxTBw1IKXCqq-RlncshwWIfrrqxc8y5nehqiqMtRztwD
# Defaults to 10 seconds
# uds_timeout = 10
# If verify ssl certificate on uds server. Defaults to true
# uds_verify_ssl = true
# Secret to get access to admin commands (Currently only stats commands). No default for this. # Secret to get access to admin commands (Currently only stats commands). No default for this.
# Admin commands and only allowed from "allow" ips # Admin commands and only allowed from "allow" ips
@ -50,3 +62,7 @@ secret = MySecret
# Only use IPs, no networks allowed # Only use IPs, no networks allowed
# defaults to localhost (change if listen address is different from 0.0.0.0) # defaults to localhost (change if listen address is different from 0.0.0.0)
allow = 127.0.0.1 allow = 127.0.0.1
# If use uvloop as event loop. Defaults to true
# use_uvloop = true

View File

@ -39,6 +39,8 @@ import ssl
import socket import socket
import logging import logging
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
# event for stop notification
import threading
import typing import typing
try: try:
@ -59,16 +61,15 @@ if typing.TYPE_CHECKING:
from multiprocessing.connection import Connection from multiprocessing.connection import Connection
from multiprocessing.managers import Namespace from multiprocessing.managers import Namespace
BACKLOG = 1024
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
do_stop = False running: threading.Event = threading.Event()
def stop_signal(signum: int, frame: typing.Any) -> None: def stop_signal(signum: int, frame: typing.Any) -> None:
global do_stop global running
do_stop = True running.clear()
logger.debug('SIGNAL %s, frame: %s', signum, frame) logger.debug('SIGNAL %s, frame: %s', signum, frame)
@ -76,26 +77,26 @@ def setup_log(cfg: config.ConfigurationType) -> None:
from logging.handlers import RotatingFileHandler from logging.handlers import RotatingFileHandler
# Update logging if needed # Update logging if needed
if cfg.log_file: if cfg.logfile:
fileh = RotatingFileHandler( fileh = RotatingFileHandler(
filename=cfg.log_file, filename=cfg.logfile,
mode='a', mode='a',
maxBytes=cfg.log_size, maxBytes=cfg.logsize,
backupCount=cfg.log_number, backupCount=cfg.lognumber,
) )
formatter = logging.Formatter(consts.LOGFORMAT) formatter = logging.Formatter(consts.LOGFORMAT)
fileh.setFormatter(formatter) fileh.setFormatter(formatter)
log = logging.getLogger() log = logging.getLogger()
log.setLevel(cfg.log_level) log.setLevel(cfg.loglevel)
# for hdlr in log.handlers[:]: # for hdlr in log.handlers[:]:
# log.removeHandler(hdlr) # log.removeHandler(hdlr)
log.addHandler(fileh) log.addHandler(fileh)
else: else:
# Setup basic logging # Setup basic logging
log = logging.getLogger() log = logging.getLogger()
log.setLevel(cfg.log_level) log.setLevel(cfg.loglevel)
handler = logging.StreamHandler(sys.stderr) handler = logging.StreamHandler(sys.stderr)
handler.setLevel(cfg.log_level) handler.setLevel(cfg.loglevel)
formatter = logging.Formatter( formatter = logging.Formatter(
'%(levelname)s - %(message)s' '%(levelname)s - %(message)s'
) # Basic log format, nice for syslog ) # Basic log format, nice for syslog
@ -107,10 +108,7 @@ async def tunnel_proc_async(
pipe: 'Connection', cfg: config.ConfigurationType, ns: 'Namespace' pipe: 'Connection', cfg: config.ConfigurationType, ns: 'Namespace'
) -> None: ) -> None:
try: loop = asyncio.get_running_loop()
loop = asyncio.get_running_loop()
except RuntimeError: # older python versions
loop = asyncio.get_event_loop()
tasks: typing.List[asyncio.Task] = [] tasks: typing.List[asyncio.Task] = []
@ -123,8 +121,13 @@ async def tunnel_proc_async(
] = pipe.recv() ] = pipe.recv()
if msg: if msg:
return msg return msg
except EOFError:
logger.debug('Parent process closed connection')
pipe.close()
return None, None
except Exception: except Exception:
logger.exception('Receiving data from parent process') logger.exception('Receiving data from parent process')
pipe.close()
return None, None return None, None
async def run_server() -> None: async def run_server() -> None:
@ -133,7 +136,15 @@ async def tunnel_proc_async(
# Generate SSL context # Generate SSL context
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
context.load_cert_chain(cfg.ssl_certificate, cfg.ssl_certificate_key) args: typing.Dict[str, typing.Any] = {
'certfile': cfg.ssl_certificate,
}
if cfg.ssl_certificate_key:
args['keyfile'] = cfg.ssl_certificate_key
if cfg.ssl_password:
args['password'] = cfg.ssl_password
context.load_cert_chain(**args)
if cfg.ssl_ciphers: if cfg.ssl_ciphers:
context.set_ciphers(cfg.ssl_ciphers) context.set_ciphers(cfg.ssl_ciphers)
@ -141,29 +152,44 @@ async def tunnel_proc_async(
if cfg.ssl_dhparam: if cfg.ssl_dhparam:
context.load_dh_params(cfg.ssl_dhparam) context.load_dh_params(cfg.ssl_dhparam)
while True: try:
address: typing.Optional[typing.Tuple[str, int]] = ('', 0) while True:
try: address: typing.Optional[typing.Tuple[str, int]] = ('', 0)
(sock, address) = await loop.run_in_executor(None, get_socket) try:
if not sock: (sock, address) = await loop.run_in_executor(None, get_socket)
break # No more sockets, exit if not sock:
logger.debug(f'CONNECTION from {address!r} (pid: {os.getpid()})') break # No more sockets, exit
tasks.append(asyncio.create_task(tunneler(sock, context))) logger.debug(f'CONNECTION from {address!r} (pid: {os.getpid()})')
except Exception: tasks.append(asyncio.create_task(tunneler(sock, context)))
logger.error('NEGOTIATION ERROR from %s', address[0] if address else 'unknown') except asyncio.CancelledError:
raise
except Exception:
logger.error('NEGOTIATION ERROR from %s', address[0] if address else 'unknown')
except asyncio.CancelledError:
pass # Stop
# create task for server # create task for server
tasks.append(asyncio.create_task(run_server())) tasks.append(asyncio.create_task(run_server()))
while tasks: try:
to_wait = tasks[:] # Get a copy of the list, and clean the original while tasks and running.is_set():
# Wait for tasks to finish to_wait = tasks[:] # Get a copy of the list, and clean the original
done, _ = await asyncio.wait(to_wait, return_when=asyncio.FIRST_COMPLETED) # Wait for tasks to finish
# Remove finished tasks done, _ = await asyncio.wait(to_wait, return_when=asyncio.FIRST_COMPLETED, timeout=2)
for task in done: # Remove finished tasks
tasks.remove(task) for task in done:
if task.exception(): tasks.remove(task)
logger.exception('TUNNEL ERROR') if task.exception():
logger.exception('TUNNEL ERROR')
except asyncio.CancelledError:
running.clear() # ensure we stop
# If any task is still running, cancel it
for task in tasks:
task.cancel()
# Wait for all tasks to finish
await asyncio.gather(*tasks, return_exceptions=True)
logger.info('PROCESS %s stopped', os.getpid()) logger.info('PROCESS %s stopped', os.getpid())
@ -205,11 +231,11 @@ def tunnel_main(args: 'argparse.Namespace') -> None:
# logger.warning('socket.REUSEPORT not available') # logger.warning('socket.REUSEPORT not available')
try: try:
sock.bind((cfg.listen_address, cfg.listen_port)) sock.bind((cfg.listen_address, cfg.listen_port))
sock.listen(BACKLOG) sock.listen(consts.BACKLOG)
# If running as root, and requested drop privileges after port bind # If running as root, and requested drop privileges after port bind
if os.getuid() == 0 and cfg.user: if os.getuid() == 0 and cfg.user:
logger.debug('Changing to user %s', cfg.user) logger.debug('Changing to user %s', cfg.user)
pwu = pwd.getpwnam(cfg.user) pwu = pwd.getpwnam(cfg.user)
# os.setgid(pwu.pw_gid) # os.setgid(pwu.pw_gid)
os.setuid(pwu.pw_uid) os.setuid(pwu.pw_uid)
@ -233,16 +259,22 @@ def tunnel_main(args: 'argparse.Namespace') -> None:
return return
# Setup signal handlers # Setup signal handlers
signal.signal(signal.SIGINT, stop_signal) try:
signal.signal(signal.SIGTERM, stop_signal) signal.signal(signal.SIGINT, stop_signal)
signal.signal(signal.SIGTERM, stop_signal)
except Exception as e:
# Signal not available on threads, and we use threads on tests,
# so we will ignore this because on tests signals are not important
logger.warning('Signal not available: %s', e)
stats_collector = stats.GlobalStats() stats_collector = stats.GlobalStats()
prcs = processes.Processes(tunnel_proc_async, cfg, stats_collector.ns) prcs = processes.Processes(tunnel_proc_async, cfg, stats_collector.ns)
running.set() # Signal we are running
with ThreadPoolExecutor(max_workers=256) as executor: with ThreadPoolExecutor(max_workers=256) as executor:
try: try:
while not do_stop: while running.is_set():
try: try:
client, addr = sock.accept() client, addr = sock.accept()
logger.info('CONNECTION from %s', addr) logger.info('CONNECTION from %s', addr)