mirror of
https://github.com/dkmstr/openuds.git
synced 2024-12-22 13:34:04 +03:00
updates from 4.0 backported
This commit is contained in:
parent
b7962a24f1
commit
adeb6b2a46
@ -126,9 +126,7 @@ async def main():
|
||||
data = client.recv(4)
|
||||
print(data)
|
||||
# Upgrade connection to SSL, and use asyncio to handle the rest
|
||||
transport: 'asyncio.transports.Transport'
|
||||
protocol: TunnelProtocol
|
||||
(transport, protocol) = await loop.connect_accepted_socket( # type: ignore
|
||||
(_, protocol) = await loop.connect_accepted_socket( # type: ignore
|
||||
lambda: TunnelProtocol(), client, ssl=context
|
||||
)
|
||||
|
||||
|
@ -43,18 +43,21 @@ class ConfigurationType(typing.NamedTuple):
|
||||
pidfile: str
|
||||
user: str
|
||||
|
||||
log_level: str
|
||||
log_file: str
|
||||
log_size: int
|
||||
log_number: int
|
||||
loglevel: str
|
||||
logfile: str
|
||||
logsize: int
|
||||
lognumber: int
|
||||
|
||||
listen_address: str
|
||||
listen_port: int
|
||||
|
||||
ipv6: bool
|
||||
|
||||
workers: int
|
||||
|
||||
ssl_certificate: str
|
||||
ssl_certificate_key: str
|
||||
ssl_password: str
|
||||
ssl_ciphers: str
|
||||
ssl_dhparam: str
|
||||
|
||||
@ -111,15 +114,17 @@ def read(
|
||||
return ConfigurationType(
|
||||
pidfile=uds.get('pidfile', ''),
|
||||
user=uds.get('user', ''),
|
||||
log_level=uds.get('loglevel', 'ERROR'),
|
||||
log_file=uds.get('logfile', ''),
|
||||
log_size=int(logsize) * 1024 * 1024,
|
||||
log_number=int(uds.get('lognumber', '3')),
|
||||
loglevel=uds.get('loglevel', 'ERROR'),
|
||||
logfile=uds.get('logfile', ''),
|
||||
logsize=int(logsize) * 1024 * 1024,
|
||||
lognumber=int(uds.get('lognumber', '3')),
|
||||
listen_address=uds.get('address', '0.0.0.0'),
|
||||
listen_port=int(uds.get('port', '443')),
|
||||
ipv6=uds.get('ipv6', 'false').lower() == 'true',
|
||||
workers=int(uds.get('workers', '0')) or multiprocessing.cpu_count(),
|
||||
ssl_certificate=uds['ssl_certificate'],
|
||||
ssl_certificate_key=uds['ssl_certificate_key'],
|
||||
ssl_certificate_key=uds.get('ssl_certificate_key', ''),
|
||||
ssl_password=uds.get('ssl_password', ''),
|
||||
ssl_ciphers=uds.get('ssl_ciphers'),
|
||||
ssl_dhparam=uds.get('ssl_dhparam'),
|
||||
uds_server=uds_server,
|
||||
|
@ -28,34 +28,48 @@
|
||||
'''
|
||||
Author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
'''
|
||||
import typing
|
||||
|
||||
DEBUG = True
|
||||
|
||||
if DEBUG:
|
||||
CONFIGFILE = 'udstunnel.conf'
|
||||
LOGFORMAT = '%(levelname)s %(asctime)s %(message)s'
|
||||
else:
|
||||
CONFIGFILE = '/etc/udstunnel.conf'
|
||||
LOGFORMAT = '%(levelname)s %(asctime)s %(message)s'
|
||||
CONFIGFILE: typing.Final[str] = '/etc/udstunnel.conf' if not DEBUG else 'udstunnel.conf'
|
||||
LOGFORMAT: typing.Final[str] = (
|
||||
'%(levelname)s %(asctime)s %(message)s'
|
||||
if not DEBUG
|
||||
else '%(levelname)s %(asctime)s %(message)s'
|
||||
)
|
||||
|
||||
# MAX Length of read buffer for proxyed requests
|
||||
BUFFER_SIZE = 1024 * 16
|
||||
BUFFER_SIZE: typing.Final[int] = 1024 * 16
|
||||
# Handshake for conversation start
|
||||
HANDSHAKE_V1 = b'\x5AMGB\xA5\x01\x00'
|
||||
HANDSHAKE_V1: typing.Final[bytes] = b'\x5AMGB\xA5\x01\x00'
|
||||
# Ticket length
|
||||
TICKET_LENGTH = 48
|
||||
TICKET_LENGTH: typing.Final[int] = 48
|
||||
# Max Admin password length (stats basically right now)
|
||||
PASSWORD_LENGTH = 64
|
||||
PASSWORD_LENGTH: typing.Final[int] = 64
|
||||
# Bandwidth calc time lapse
|
||||
BANDWIDTH_TIME = 10
|
||||
BANDWIDTH_TIME: typing.Final[int] = 10
|
||||
|
||||
# Commands LENGTH (all same length)
|
||||
COMMAND_LENGTH = 4
|
||||
COMMAND_LENGTH: typing.Final[int] = 4
|
||||
|
||||
VERSION = 'v2.0.0'
|
||||
VERSION: typing.Final[str] = 'v2.0.0'
|
||||
|
||||
# Valid commands
|
||||
COMMAND_OPEN = b'OPEN'
|
||||
COMMAND_TEST = b'TEST'
|
||||
COMMAND_STAT = b'STAT' # full stats
|
||||
COMMAND_INFO = b'INFO' # Basic stats, currently same as FULL
|
||||
COMMAND_OPEN: typing.Final[bytes] = b'OPEN'
|
||||
COMMAND_TEST: typing.Final[bytes] = b'TEST'
|
||||
COMMAND_STAT: typing.Final[bytes] = b'STAT' # full stats
|
||||
COMMAND_INFO: typing.Final[bytes] = b'INFO' # Basic stats, currently same as FULL
|
||||
|
||||
RESPONSE_ERROR_TICKET: typing.Final[bytes] = b'ERROR_TICKET'
|
||||
RESPONSE_ERROR_COMMAND: typing.Final[bytes] = b'ERROR_COMMAND'
|
||||
RESPONSE_ERROR_TIMEOUT: typing.Final[bytes] = b'TIMEOUT'
|
||||
RESPONSE_FORBIDDEN: typing.Final[bytes] = b'FORBIDDEN'
|
||||
|
||||
RESPONSE_OK: typing.Final[bytes] = b'OK'
|
||||
|
||||
# Timeout for command
|
||||
TIMEOUT_COMMAND: typing.Final[int] = 3
|
||||
|
||||
# Backlog for listen socket
|
||||
BACKLOG = 1024
|
||||
|
@ -156,13 +156,16 @@ class Processes:
|
||||
ns: 'Namespace',
|
||||
) -> None:
|
||||
if cfg.use_uvloop:
|
||||
import uvloop
|
||||
try:
|
||||
import uvloop
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
with asyncio.Runner(loop_factory=uvloop.new_event_loop) as runner:
|
||||
runner.run(proc(conn, cfg, ns))
|
||||
else:
|
||||
uvloop.install()
|
||||
asyncio.run(proc(conn, cfg, ns))
|
||||
if sys.version_info >= (3, 11):
|
||||
with asyncio.Runner(loop_factory=uvloop.new_event_loop) as runner:
|
||||
runner.run(proc(conn, cfg, ns))
|
||||
else:
|
||||
uvloop.install()
|
||||
asyncio.run(proc(conn, cfg, ns))
|
||||
except ImportError:
|
||||
logger.warning('uvloop not found, using default asyncio')
|
||||
else:
|
||||
asyncio.run(proc(conn, cfg, ns))
|
||||
|
@ -46,6 +46,7 @@ logger = logging.getLogger(__name__)
|
||||
class Proxy:
|
||||
cfg: 'config.ConfigurationType'
|
||||
ns: 'Namespace'
|
||||
finished: asyncio.Event
|
||||
|
||||
def __init__(self, cfg: 'config.ConfigurationType', ns: 'Namespace') -> None:
|
||||
self.cfg = cfg
|
||||
@ -63,22 +64,27 @@ class Proxy:
|
||||
addr = source.getpeername()
|
||||
except Exception:
|
||||
addr = 'Unknown'
|
||||
logger.error('Proxy error from %s: %s', addr, e)
|
||||
logger.exception('Proxy error from %s: %s (%s--%s)', addr, e, source, context)
|
||||
|
||||
async def proxy(self, source: socket.socket, context: 'ssl.SSLContext') -> None:
|
||||
loop = asyncio.get_event_loop()
|
||||
loop = asyncio.get_running_loop()
|
||||
# Handshake correct in this point, upgrade the connection to TSL and let
|
||||
# the protocol controller do the rest
|
||||
self.finished = asyncio.Event()
|
||||
|
||||
# Upgrade connection to SSL, and use asyncio to handle the rest
|
||||
try:
|
||||
protocol: tunnel.TunnelProtocol
|
||||
# (connect accepted loop not present on AbastractEventLoop definition < 3.10)
|
||||
(_, protocol) = await loop.connect_accepted_socket( # type: ignore
|
||||
lambda: tunnel.TunnelProtocol(self), source, ssl=context
|
||||
def factory() -> tunnel.TunnelProtocol:
|
||||
return tunnel.TunnelProtocol(self)
|
||||
# (connect accepted loop not present on AbastractEventLoop definition < 3.10), that's why we use ignore
|
||||
await loop.connect_accepted_socket( # type: ignore
|
||||
factory, source, ssl=context
|
||||
)
|
||||
|
||||
await protocol.finished
|
||||
# Wait for connection to be closed
|
||||
await self.finished.wait()
|
||||
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass # Return on cancel
|
||||
|
||||
|
@ -31,6 +31,7 @@ Author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
import asyncio
|
||||
import typing
|
||||
import logging
|
||||
import socket
|
||||
|
||||
import aiohttp
|
||||
|
||||
@ -46,8 +47,6 @@ if typing.TYPE_CHECKING:
|
||||
|
||||
# Protocol
|
||||
class TunnelProtocol(asyncio.Protocol):
|
||||
# future to mark eof
|
||||
finished: asyncio.Future
|
||||
# Transport and other side of tunnel
|
||||
transport: 'asyncio.transports.Transport'
|
||||
other_side: 'TunnelProtocol'
|
||||
@ -56,7 +55,7 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
# Command buffer
|
||||
cmd: bytes
|
||||
# Ticket
|
||||
notify_ticket: bytes
|
||||
notify_ticket: bytes # Only exists on "slave" transport (that is, tunnel from us to remote machine)
|
||||
# owner Proxy class
|
||||
owner: 'proxy.Proxy'
|
||||
# source of connection
|
||||
@ -66,6 +65,8 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
stats_manager: stats.Stats
|
||||
# counter
|
||||
counter: stats.StatsSingleCounter
|
||||
# If there is a timeout task running
|
||||
timeout_task: typing.Optional[asyncio.Task] = None
|
||||
|
||||
def __init__(
|
||||
self, owner: 'proxy.Proxy', other_side: typing.Optional['TunnelProtocol'] = None
|
||||
@ -82,19 +83,23 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
self.stats_manager = stats.Stats(owner.ns)
|
||||
self.counter = self.stats_manager.as_sent_counter()
|
||||
self.runner = self.do_command
|
||||
# Set starting timeout task, se we dont get hunged on connections without data
|
||||
self.set_timeout(consts.TIMEOUT_COMMAND)
|
||||
|
||||
# transport is undefined until connection_made is called
|
||||
self.finished = asyncio.Future()
|
||||
self.cmd = b''
|
||||
self.notify_ticket = b''
|
||||
self.owner = owner
|
||||
self.source = ('', 0)
|
||||
self.destination = ('', 0)
|
||||
|
||||
|
||||
def process_open(self) -> None:
|
||||
# Open Command has the ticket behind it
|
||||
|
||||
if len(self.cmd) < consts.TICKET_LENGTH + consts.COMMAND_LENGTH:
|
||||
# Reactivate timeout, will be deactivated on do_command
|
||||
self.set_timeout(consts.TIMEOUT_COMMAND)
|
||||
return # Wait for more data to complete OPEN command
|
||||
|
||||
# Ticket received, now process it with UDS
|
||||
@ -106,7 +111,7 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
# clean up the command
|
||||
self.cmd = b''
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
async def open_other_side() -> None:
|
||||
try:
|
||||
@ -115,7 +120,7 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error('ERROR %s', e.args[0] if e.args else e)
|
||||
self.transport.write(b'ERROR_TICKET')
|
||||
self.transport.write(consts.RESPONSE_ERROR_TICKET)
|
||||
self.transport.close() # And force close
|
||||
return
|
||||
|
||||
@ -130,10 +135,12 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
)
|
||||
|
||||
try:
|
||||
family = socket.AF_INET6 if ':' in self.destination[0] or self.owner.cfg.ipv6 else socket.AF_INET
|
||||
(_, protocol) = await loop.create_connection(
|
||||
lambda: TunnelProtocol(self.owner, self),
|
||||
self.destination[0],
|
||||
self.destination[1],
|
||||
family=family,
|
||||
)
|
||||
self.other_side = typing.cast('TunnelProtocol', protocol)
|
||||
|
||||
@ -145,6 +152,7 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
logger.error('Error opening connection: %s', e)
|
||||
self.close_connection()
|
||||
|
||||
# add open other side to the loop
|
||||
loop.create_task(open_other_side())
|
||||
# From now, proxy connection
|
||||
self.runner = self.do_proxy
|
||||
@ -160,7 +168,7 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
# Check valid source ip
|
||||
if self.transport.get_extra_info('peername')[0] not in self.owner.cfg.allow:
|
||||
# Invalid source
|
||||
self.transport.write(b'FORBIDDEN')
|
||||
self.transport.write(consts.RESPONSE_FORBIDDEN)
|
||||
return
|
||||
|
||||
# Check password
|
||||
@ -171,7 +179,7 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
|
||||
if passwd.decode(errors='ignore') != self.owner.cfg.secret:
|
||||
# Invalid password
|
||||
self.transport.write(b'FORBIDDEN')
|
||||
self.transport.write(consts.RESPONSE_FORBIDDEN)
|
||||
return
|
||||
|
||||
data = stats.GlobalStats.get_stats(self.owner.ns)
|
||||
@ -184,18 +192,50 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
finally:
|
||||
self.close_connection()
|
||||
|
||||
async def timeout(self, wait: int) -> None:
|
||||
""" Timeout can only occur while waiting for a command."""
|
||||
try:
|
||||
await asyncio.sleep(wait)
|
||||
logger.error('TIMEOUT FROM %s', self.pretty_source())
|
||||
self.transport.write(consts.RESPONSE_ERROR_TIMEOUT)
|
||||
self.close_connection()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
def set_timeout(self, wait: int) -> None:
|
||||
"""Set a timeout for this connection.
|
||||
If reached, the connection will be closed.
|
||||
|
||||
Args:
|
||||
wait (int): Timeout in seconds
|
||||
|
||||
"""
|
||||
if self.timeout_task:
|
||||
self.timeout_task.cancel()
|
||||
self.timeout_task = asyncio.create_task(self.timeout(wait))
|
||||
|
||||
def clean_timeout(self) -> None:
|
||||
"""Clean the timeout task if any.
|
||||
"""
|
||||
if self.timeout_task:
|
||||
self.timeout_task.cancel()
|
||||
self.timeout_task = None
|
||||
|
||||
def do_command(self, data: bytes) -> None:
|
||||
self.cmd += data
|
||||
if len(self.cmd) >= consts.COMMAND_LENGTH:
|
||||
if self.cmd == b'':
|
||||
logger.info('CONNECT FROM %s', self.pretty_source())
|
||||
|
||||
self.clean_timeout()
|
||||
self.cmd += data
|
||||
# Ensure we don't get a timeout
|
||||
if len(self.cmd) >= consts.COMMAND_LENGTH:
|
||||
command = self.cmd[: consts.COMMAND_LENGTH]
|
||||
try:
|
||||
if command == consts.COMMAND_OPEN:
|
||||
self.process_open()
|
||||
elif command == consts.COMMAND_TEST:
|
||||
logger.info('COMMAND: TEST')
|
||||
self.transport.write(b'OK')
|
||||
self.transport.write(consts.RESPONSE_OK)
|
||||
self.close_connection()
|
||||
return
|
||||
elif command in (consts.COMMAND_STAT, consts.COMMAND_INFO):
|
||||
@ -206,9 +246,11 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
raise Exception('Invalid command')
|
||||
except Exception:
|
||||
logger.error('ERROR from %s', self.pretty_source())
|
||||
self.transport.write(b'ERROR_COMMAND')
|
||||
self.transport.write(consts.RESPONSE_ERROR_COMMAND)
|
||||
self.close_connection()
|
||||
return
|
||||
else:
|
||||
self.set_timeout(consts.TIMEOUT_COMMAND)
|
||||
|
||||
# if not enough data to process command, wait for more
|
||||
|
||||
@ -221,6 +263,7 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
|
||||
def connection_made(self, transport: 'asyncio.transports.BaseTransport') -> None:
|
||||
logger.debug('Connection made: %s', transport.get_extra_info('peername'))
|
||||
self.main = True # This is the main connection
|
||||
|
||||
# We know for sure that the transport is a Transport.
|
||||
self.transport = typing.cast('asyncio.transports.Transport', transport)
|
||||
@ -239,10 +282,12 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
)
|
||||
)
|
||||
self.notify_ticket = b'' # Clean up so no more notifications
|
||||
else: # No ticket, this is "main" connection (from client to us). Notify owner that we are done
|
||||
self.owner.finished.set()
|
||||
|
||||
def connection_lost(self, exc: typing.Optional[Exception]) -> None:
|
||||
logger.debug('Connection closed : %s', exc)
|
||||
self.finished.set_result(True)
|
||||
# Ensure close other side if any
|
||||
if self.other_side is not self:
|
||||
self.other_side.transport.close()
|
||||
else:
|
||||
@ -250,12 +295,17 @@ class TunnelProtocol(asyncio.Protocol):
|
||||
self.notifyEnd()
|
||||
|
||||
# helpers
|
||||
@staticmethod
|
||||
def pretty_address(address: typing.Tuple[str, int]) -> str:
|
||||
if ':' in address[0]:
|
||||
return '[' + address[0] + ']:' + str(address[1])
|
||||
return address[0] + ':' + str(address[1])
|
||||
# source address, pretty format
|
||||
def pretty_source(self) -> str:
|
||||
return self.source[0] + ':' + str(self.source[1])
|
||||
return TunnelProtocol.pretty_address(self.source)
|
||||
|
||||
def pretty_destination(self) -> str:
|
||||
return self.destination[0] + ':' + str(self.destination[1])
|
||||
return TunnelProtocol.pretty_address(self.destination)
|
||||
|
||||
def close_connection(self):
|
||||
self.transport.close()
|
||||
|
@ -20,14 +20,21 @@ lognumber = 3
|
||||
# Listen address. Defaults to 0.0.0.0
|
||||
address = 0.0.0.0
|
||||
|
||||
# Number of workers. Defaults to 0 (means "as much as cores")
|
||||
workers = 2
|
||||
|
||||
# Listening port
|
||||
port = 7777
|
||||
|
||||
# If force ipv6, defaults to false
|
||||
# Note: if listen address is an ipv6 address, this will be forced to true
|
||||
# This will force dns resolution to ipv6
|
||||
ipv6 = false
|
||||
|
||||
# Number of workers. Defaults to 0 (means "as much as cores")
|
||||
workers = 2
|
||||
|
||||
|
||||
# SSL Related parameters.
|
||||
ssl_certificate = /etc/certs/server.pem
|
||||
# Key can be included on certificate file, so this is optional
|
||||
ssl_certificate_key = /etc/certs/key.pem
|
||||
# ssl_ciphers and ssl_dhparam are optional.
|
||||
ssl_ciphers = ECDHE-RSA-AES256-GCM-SHA512:DHE-RSA-AES256-GCM-SHA512:ECDHE-RSA-AES256-GCM-SHA384:DHE-RSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-SHA384
|
||||
@ -40,6 +47,11 @@ ssl_dhparam = /etc/certs/dhparam.pem
|
||||
# https://www.example.com:14333/uds/rest/tunnel/ticket
|
||||
uds_server = http://172.27.0.1:8000/uds/rest/tunnel/ticket
|
||||
uds_token = eBCeFxTBw1IKXCqq-RlncshwWIfrrqxc8y5nehqiqMtRztwD
|
||||
# Defaults to 10 seconds
|
||||
# uds_timeout = 10
|
||||
|
||||
# If verify ssl certificate on uds server. Defaults to true
|
||||
# uds_verify_ssl = true
|
||||
|
||||
# Secret to get access to admin commands (Currently only stats commands). No default for this.
|
||||
# Admin commands and only allowed from "allow" ips
|
||||
@ -50,3 +62,7 @@ secret = MySecret
|
||||
# Only use IPs, no networks allowed
|
||||
# defaults to localhost (change if listen address is different from 0.0.0.0)
|
||||
allow = 127.0.0.1
|
||||
|
||||
|
||||
# If use uvloop as event loop. Defaults to true
|
||||
# use_uvloop = true
|
@ -39,6 +39,8 @@ import ssl
|
||||
import socket
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
# event for stop notification
|
||||
import threading
|
||||
import typing
|
||||
|
||||
try:
|
||||
@ -59,16 +61,15 @@ if typing.TYPE_CHECKING:
|
||||
from multiprocessing.connection import Connection
|
||||
from multiprocessing.managers import Namespace
|
||||
|
||||
BACKLOG = 1024
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
do_stop = False
|
||||
running: threading.Event = threading.Event()
|
||||
|
||||
|
||||
def stop_signal(signum: int, frame: typing.Any) -> None:
|
||||
global do_stop
|
||||
do_stop = True
|
||||
global running
|
||||
running.clear()
|
||||
logger.debug('SIGNAL %s, frame: %s', signum, frame)
|
||||
|
||||
|
||||
@ -76,26 +77,26 @@ def setup_log(cfg: config.ConfigurationType) -> None:
|
||||
from logging.handlers import RotatingFileHandler
|
||||
|
||||
# Update logging if needed
|
||||
if cfg.log_file:
|
||||
if cfg.logfile:
|
||||
fileh = RotatingFileHandler(
|
||||
filename=cfg.log_file,
|
||||
filename=cfg.logfile,
|
||||
mode='a',
|
||||
maxBytes=cfg.log_size,
|
||||
backupCount=cfg.log_number,
|
||||
maxBytes=cfg.logsize,
|
||||
backupCount=cfg.lognumber,
|
||||
)
|
||||
formatter = logging.Formatter(consts.LOGFORMAT)
|
||||
fileh.setFormatter(formatter)
|
||||
log = logging.getLogger()
|
||||
log.setLevel(cfg.log_level)
|
||||
log.setLevel(cfg.loglevel)
|
||||
# for hdlr in log.handlers[:]:
|
||||
# log.removeHandler(hdlr)
|
||||
log.addHandler(fileh)
|
||||
else:
|
||||
# Setup basic logging
|
||||
log = logging.getLogger()
|
||||
log.setLevel(cfg.log_level)
|
||||
log.setLevel(cfg.loglevel)
|
||||
handler = logging.StreamHandler(sys.stderr)
|
||||
handler.setLevel(cfg.log_level)
|
||||
handler.setLevel(cfg.loglevel)
|
||||
formatter = logging.Formatter(
|
||||
'%(levelname)s - %(message)s'
|
||||
) # Basic log format, nice for syslog
|
||||
@ -107,10 +108,7 @@ async def tunnel_proc_async(
|
||||
pipe: 'Connection', cfg: config.ConfigurationType, ns: 'Namespace'
|
||||
) -> None:
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError: # older python versions
|
||||
loop = asyncio.get_event_loop()
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
tasks: typing.List[asyncio.Task] = []
|
||||
|
||||
@ -123,8 +121,13 @@ async def tunnel_proc_async(
|
||||
] = pipe.recv()
|
||||
if msg:
|
||||
return msg
|
||||
except EOFError:
|
||||
logger.debug('Parent process closed connection')
|
||||
pipe.close()
|
||||
return None, None
|
||||
except Exception:
|
||||
logger.exception('Receiving data from parent process')
|
||||
pipe.close()
|
||||
return None, None
|
||||
|
||||
async def run_server() -> None:
|
||||
@ -133,7 +136,15 @@ async def tunnel_proc_async(
|
||||
|
||||
# Generate SSL context
|
||||
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||
context.load_cert_chain(cfg.ssl_certificate, cfg.ssl_certificate_key)
|
||||
args: typing.Dict[str, typing.Any] = {
|
||||
'certfile': cfg.ssl_certificate,
|
||||
}
|
||||
if cfg.ssl_certificate_key:
|
||||
args['keyfile'] = cfg.ssl_certificate_key
|
||||
if cfg.ssl_password:
|
||||
args['password'] = cfg.ssl_password
|
||||
|
||||
context.load_cert_chain(**args)
|
||||
|
||||
if cfg.ssl_ciphers:
|
||||
context.set_ciphers(cfg.ssl_ciphers)
|
||||
@ -141,29 +152,44 @@ async def tunnel_proc_async(
|
||||
if cfg.ssl_dhparam:
|
||||
context.load_dh_params(cfg.ssl_dhparam)
|
||||
|
||||
while True:
|
||||
address: typing.Optional[typing.Tuple[str, int]] = ('', 0)
|
||||
try:
|
||||
(sock, address) = await loop.run_in_executor(None, get_socket)
|
||||
if not sock:
|
||||
break # No more sockets, exit
|
||||
logger.debug(f'CONNECTION from {address!r} (pid: {os.getpid()})')
|
||||
tasks.append(asyncio.create_task(tunneler(sock, context)))
|
||||
except Exception:
|
||||
logger.error('NEGOTIATION ERROR from %s', address[0] if address else 'unknown')
|
||||
try:
|
||||
while True:
|
||||
address: typing.Optional[typing.Tuple[str, int]] = ('', 0)
|
||||
try:
|
||||
(sock, address) = await loop.run_in_executor(None, get_socket)
|
||||
if not sock:
|
||||
break # No more sockets, exit
|
||||
logger.debug(f'CONNECTION from {address!r} (pid: {os.getpid()})')
|
||||
tasks.append(asyncio.create_task(tunneler(sock, context)))
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
logger.error('NEGOTIATION ERROR from %s', address[0] if address else 'unknown')
|
||||
except asyncio.CancelledError:
|
||||
pass # Stop
|
||||
|
||||
# create task for server
|
||||
tasks.append(asyncio.create_task(run_server()))
|
||||
|
||||
while tasks:
|
||||
to_wait = tasks[:] # Get a copy of the list, and clean the original
|
||||
# Wait for tasks to finish
|
||||
done, _ = await asyncio.wait(to_wait, return_when=asyncio.FIRST_COMPLETED)
|
||||
# Remove finished tasks
|
||||
for task in done:
|
||||
tasks.remove(task)
|
||||
if task.exception():
|
||||
logger.exception('TUNNEL ERROR')
|
||||
try:
|
||||
while tasks and running.is_set():
|
||||
to_wait = tasks[:] # Get a copy of the list, and clean the original
|
||||
# Wait for tasks to finish
|
||||
done, _ = await asyncio.wait(to_wait, return_when=asyncio.FIRST_COMPLETED, timeout=2)
|
||||
# Remove finished tasks
|
||||
for task in done:
|
||||
tasks.remove(task)
|
||||
if task.exception():
|
||||
logger.exception('TUNNEL ERROR')
|
||||
except asyncio.CancelledError:
|
||||
running.clear() # ensure we stop
|
||||
|
||||
# If any task is still running, cancel it
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
|
||||
# Wait for all tasks to finish
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
logger.info('PROCESS %s stopped', os.getpid())
|
||||
|
||||
@ -205,11 +231,11 @@ def tunnel_main(args: 'argparse.Namespace') -> None:
|
||||
# logger.warning('socket.REUSEPORT not available')
|
||||
try:
|
||||
sock.bind((cfg.listen_address, cfg.listen_port))
|
||||
sock.listen(BACKLOG)
|
||||
sock.listen(consts.BACKLOG)
|
||||
|
||||
# If running as root, and requested drop privileges after port bind
|
||||
if os.getuid() == 0 and cfg.user:
|
||||
logger.debug('Changing to user %s', cfg.user)
|
||||
logger.debug('Changing to user %s', cfg.user)
|
||||
pwu = pwd.getpwnam(cfg.user)
|
||||
# os.setgid(pwu.pw_gid)
|
||||
os.setuid(pwu.pw_uid)
|
||||
@ -233,16 +259,22 @@ def tunnel_main(args: 'argparse.Namespace') -> None:
|
||||
return
|
||||
|
||||
# Setup signal handlers
|
||||
signal.signal(signal.SIGINT, stop_signal)
|
||||
signal.signal(signal.SIGTERM, stop_signal)
|
||||
try:
|
||||
signal.signal(signal.SIGINT, stop_signal)
|
||||
signal.signal(signal.SIGTERM, stop_signal)
|
||||
except Exception as e:
|
||||
# Signal not available on threads, and we use threads on tests,
|
||||
# so we will ignore this because on tests signals are not important
|
||||
logger.warning('Signal not available: %s', e)
|
||||
|
||||
stats_collector = stats.GlobalStats()
|
||||
|
||||
prcs = processes.Processes(tunnel_proc_async, cfg, stats_collector.ns)
|
||||
|
||||
running.set() # Signal we are running
|
||||
with ThreadPoolExecutor(max_workers=256) as executor:
|
||||
try:
|
||||
while not do_stop:
|
||||
while running.is_set():
|
||||
try:
|
||||
client, addr = sock.accept()
|
||||
logger.info('CONNECTION from %s', addr)
|
||||
|
Loading…
Reference in New Issue
Block a user