Mirror of https://github.com/dkmstr/openuds.git (synced 2024-12-23 17:34:17 +03:00)

Advancing on new tunneler

parent e486d6708d
commit 971e5984d9
@@ -143,6 +143,7 @@ class Handler(socketserver.BaseRequestHandler):
if not data:
break
self.request.sendall(data)
logger.debug('Finished process')
except Exception as e:
pass

@@ -173,8 +174,8 @@ def forward(
if __name__ == "__main__":
fs1 = forward(('fake.udsenterprise.com', 7777), '0'*64, local_port=49998)
print(f'Listening on {fs1.server_address}')
#fs2 = forward(('fake.udsenterprise.com', 7777), '1'*64, local_port=49999)
#print(f'Listening on {fs2.server_address}')
fs2 = forward(('fake.udsenterprise.com', 7777), '1'*64, local_port=49999)
print(f'Listening on {fs2.server_address}')
# time.sleep(30)
# fs.stop()
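As a quick way to exercise these listeners, a plain TCP client can connect to one of the forwarded local ports. The sketch below is only illustrative: it assumes forward() binds its listener on 127.0.0.1 at the requested local_port and relays raw bytes, and the poke() helper is hypothetical, not part of the repository.

import socket

def poke(port: int, payload: bytes = b'ping') -> bytes:
    # Connect to the local forwarder, push a payload and read back whatever
    # (if anything) the tunnel relays to us.
    with socket.create_connection(('127.0.0.1', port), timeout=5) as s:
        s.sendall(payload)
        try:
            return s.recv(1024)
        except socket.timeout:
            return b''

print(poke(49998))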
@@ -38,15 +38,18 @@ from . import config
from . import stats
from . import consts

if typing.TYPE_CHECKING:
from multiprocessing.managers import Namespace

logger = logging.getLogger(__name__)

class Proxy:
cfg: config.ConfigurationType
stat: stats.Stats
ns: 'Namespace'

def __init__(self, cfg: config.ConfigurationType) -> None:
def __init__(self, cfg: config.ConfigurationType, ns: 'Namespace') -> None:
self.cfg = cfg
self.stat = stats.Stats()
self.ns = ns

@staticmethod
def getFromUds(cfg: config.ConfigurationType, ticket: bytes) -> typing.MutableMapping[str, typing.Any]:

@@ -77,6 +80,10 @@ class Proxy:
await destination.sendall(data)
counter.add(len(data))

# Method responsible of proxying requests
async def __call__(self, source, address: typing.Tuple[str, int]) -> None:
await self.proxy(source, address)

async def stats(self, full: bool, source, address: typing.Tuple[str, int]) -> None:
# Check valid source ip
if address[0] not in self.cfg.allow:

@@ -93,17 +100,12 @@ class Proxy:

logger.info('STATS TO %s', address)

if full:
data = self.stat.full_as_csv()
else:
data = self.stat.simple_as_csv()
data = stats.GlobalStats.get_stats(self.ns)

async for v in data:
for v in data:
logger.debug('SENDING %s', v)
await source.sendall(v.encode() + b'\n')

# Method responsible of proxying requests
async def __call__(self, source, address: typing.Tuple[str, int]) -> None:
await self.proxy(source, address)

async def proxy(self, source, address: typing.Tuple[str, int]) -> None:
logger.info('OPEN FROM %s', address)

@@ -119,7 +121,6 @@ class Proxy:
logger.exception('HANDSHAKE')
logger.error('HANDSHAKE from %s', address)
await source.sendall(b'HANDSHAKE_ERROR')

# Closes connection now
return

@@ -166,7 +167,7 @@ class Proxy:
await source.sendall(b'OK')

# Initialize own stats counter
counter = await self.stat.new()
counter = stats.Stats(self.ns)

# Open remote server connection
try:

@@ -184,8 +185,7 @@ class Proxy:

logger.error('REMOTE from %s: %s', address, e)
finally:
await counter.close()
counter.close()


logger.info('CLOSED FROM %s', address)
logger.info('STATS: %s', counter.as_csv())
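For context on how the per-connection counter introduced here is used, the relay code in the earlier hunk calls counter.add(len(data)) after every successful write. Below is a minimal sketch of that pattern, assuming curio-style async sockets with recv()/sendall() and a counter object exposing add() as in the context lines above; the relay() name and the buffer size are illustrative only.

async def relay(source, destination, counter) -> None:
    # Pipe bytes one way and account for them; the opposite direction runs in
    # a sibling task with its own counter.
    while True:
        data = await source.recv(16384)
        if not data:
            break
        await destination.sendall(data)
        counter.add(len(data))

When the proxy tears the connection down, counter.close() (now synchronous after this change) pushes the accumulated totals into the shared namespace.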
@@ -28,30 +28,30 @@
'''
@author: Adolfo Gómez, dkmaster at dkmon dot com
'''
import multiprocessing
import time
import logging
import typing
import io
import ssl
import logging
import typing

import curio
import blist

from . import config
from . import consts


if typing.TYPE_CHECKING:
from multiprocessing.managers import Namespace, SyncManager

INTERVAL = 2 # Interval in seconds between stats update

logger = logging.getLogger(__name__)

# Locker for id assigner
assignLock = curio.Lock()

# Tuple index for several stats
SENT, RECV = 0, 1

# Subclasses for += operation to work
class StatsSingleCounter:
def __init__(self, parent: 'StatsConnection', for_receiving=True) -> None:
def __init__(self, parent: 'Stats', for_receiving=True) -> None:
if for_receiving:
self.adder = parent.add_recv
else:

@@ -62,56 +62,34 @@ class StatsSingleCounter:
return self


class StatsConnection:
id: int
recv: int
class Stats:
ns: 'Namespace'
sent: int
start_time: int
parent: 'Stats'
recv: int
last: float

# Bandwidth stats (SENT, RECV)
last: typing.List[int]
last_time: typing.List[float]
def __init__(self, ns: 'Namespace'):
self.ns = ns
self.ns.current += 1
self.ns.total += 1
self.sent = 0
self.recv = 0
self.last = time.monotonic()

bandwidth: typing.List[int]
max_bandwidth: typing.List[int]

def __init__(self, parent: 'Stats', id: int) -> None:
self.id = id
self.recv = self.sent = 0

now = time.time()
self.start_time = int(now)
self.parent = parent

self.last = [0, 0]
self.last_time = [now, now]
self.bandwidth = [0, 0]
self.max_bandwidth = [0, 0]

def update_bandwidth(self, kind: int, counter: int):
now = time.time()
elapsed = now - self.last_time[kind]
# Update only when enouth data
if elapsed < consts.BANDWIDTH_TIME:
return
total = counter - self.last[kind]
self.bandwidth[kind] = int(float(total) / elapsed)
self.last[kind] = counter
self.last_time[kind] = now

if self.bandwidth[kind] > self.max_bandwidth[kind]:
self.max_bandwidth[kind] = self.bandwidth[kind]
def update(self, force: bool = False):
now = time.monotonic()
if force or now - self.last > INTERVAL:
self.last = now
self.ns.recv = self.recv
self.ns.sent = self.sent

def add_recv(self, size: int) -> None:
self.recv += size
self.update_bandwidth(RECV, counter=self.recv)
self.parent.add_recv(size)
self.update()

def add_sent(self, size: int) -> None:
self.sent += size
self.update_bandwidth(SENT, counter=self.sent)
self.parent.add_sent(size)
self.update()

def as_sent_counter(self) -> 'StatsSingleCounter':
return StatsSingleCounter(self, False)

@@ -119,114 +97,34 @@ class StatsConnection:
def as_recv_counter(self) -> 'StatsSingleCounter':
return StatsSingleCounter(self, True)

async def close(self) -> None:
if self.id:
logger.debug(f'STAT {self.id} closed')
await self.parent.remove(self.id)
self.id = 0
def close(self):
self.update(True)
self.ns.current -= 1

def as_csv(self, separator: typing.Optional[str] = None) -> str:
separator = separator or ';'
# With connections of less than a second, consider them as a second
elapsed = (int(time.time()) - self.start_time)
# Stats collector thread
class GlobalStats:
manager: 'SyncManager'
ns: 'Namespace'
counter: int

return separator.join(
str(i)
for i in (
self.id,
self.start_time,
elapsed,
self.sent,
self.bandwidth[SENT],
self.max_bandwidth[SENT],
self.recv,
self.bandwidth[RECV],
self.max_bandwidth[RECV],
)
)
def __init__(self):
super().__init__()
self.manager = multiprocessing.Manager()
self.ns = self.manager.Namespace()

def __str__(self) -> str:
return f'{self.id} t:{int(time.time())-self.start_time}, r:{self.recv}, s:{self.sent}>'
# Counters
self.ns.current = 0
self.ns.total = 0
self.ns.sent = 0
self.ns.recv = 0
self.counter = 0

# For sorted array
def __lt__(self, other) -> bool:
if isinstance(other, int):
return self.id < other

if not isinstance(other, StatsConnection):
raise NotImplemented

return self.id < other.id

def __eq__(self, other) -> bool:
if isinstance(other, int):
return self.id == other

if not isinstance(other, StatsConnection):
raise NotImplemented

return self.id == other.id


class Stats:
counter_id: int

total_sent: int
total_received: int
current_connections: blist.sortedlist

def __init__(self) -> None:
# First connection will be 1
self.counter_id = 0
self.total_sent = self.total_received = 0
self.current_connections = blist.sortedlist()

async def new(self) -> StatsConnection:
"""Initializes a connection stats counter and returns it id

Returns:
str: connection id
"""
async with assignLock:
self.counter_id += 1
connection = StatsConnection(self, self.counter_id)
self.current_connections.add(connection)
return connection

def add_sent(self, size: int) -> None:
self.total_sent += size

def add_recv(self, size: int) -> None:
self.total_received += size

async def remove(self, connection_id: int) -> None:
async with assignLock:
try:
self.current_connections.remove(connection_id)
except Exception:
logger.debug(
'Tried to remove %s from connections but was not present',
connection_id,
)
# Does not exists, ignore it
pass

async def simple_as_csv(self, separator: typing.Optional[str] = None) -> typing.AsyncIterable[str]:
separator = separator or ';'
yield separator.join(
str(i)
for i in (
self.counter_id,
self.total_sent,
self.total_received,
len(self.current_connections),
)
)

async def full_as_csv(self, separator: typing.Optional[str] = None) -> typing.AsyncIterable[str]:
for i in self.current_connections:
yield i.as_csv(separator)
def info(self) -> typing.Iterable[str]:
return GlobalStats.get_stats(self.ns)

@staticmethod
def get_stats(ns: 'Namespace') -> typing.Iterable[str]:
yield ';'.join([str(ns.current), str(ns.total), str(ns.sent), str(ns.recv)])

# Stats processor, invoked from command line
async def getServerStats(detailed: bool = False) -> None:
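To make the intent of this rewrite concrete, here is a small usage sketch of the classes introduced above. It assumes the post-commit uds_tunnel.stats module is importable and only uses the methods visible in this diff (GlobalStats(), Stats(ns), add_sent/add_recv, close(), get_stats); the byte counts and the printed snapshot are illustrative.

from uds_tunnel import stats

if __name__ == '__main__':            # Manager() spawns a helper process, so guard the entry point
    gs = stats.GlobalStats()          # owns the SyncManager and the shared Namespace
    counter = stats.Stats(gs.ns)      # one per connection; bumps ns.current and ns.total
    counter.add_sent(512)             # what as_sent_counter() forwards to
    counter.add_recv(2048)            # what as_recv_counter() forwards to
    counter.close()                   # flushes sent/recv into the namespace and drops ns.current
    print(list(stats.GlobalStats.get_stats(gs.ns)))   # a single 'current;total;sent;recv' line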
@@ -1,7 +1,7 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (c) 2020 Virtual Cable S.L.U.
# Copyright (c) 2021 Virtual Cable S.L.U.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,

@@ -32,6 +32,7 @@
import sys
import argparse
import multiprocessing
import threading
import socket
import logging
import typing

@@ -47,6 +48,7 @@ from uds_tunnel import stats

if typing.TYPE_CHECKING:
from multiprocessing.connection import Connection
from multiprocessing.managers import Namespace

BACKLOG = 100

@@ -76,7 +78,9 @@ def setup_log(cfg: config.ConfigurationType) -> None:
log.addHandler(fileh)


async def tunnel_proc_async(pipe: 'Connection', cfg: config.ConfigurationType) -> None:
async def tunnel_proc_async(
pipe: 'Connection', cfg: config.ConfigurationType, ns: 'Namespace'
) -> None:
def get_socket(pipe: 'Connection') -> typing.Tuple[socket.SocketType, typing.Any]:
try:
while True:

@@ -92,7 +96,7 @@ async def tunnel_proc_async(pipe: 'Connection', cfg: config.ConfigurationType) -
pipe: 'Connection', cfg: config.ConfigurationType, group: curio.TaskGroup
) -> None:
# Instantiate a proxy redirector for this process (we only need one per process!!)
tunneler = proxy.Proxy(cfg)
tunneler = proxy.Proxy(cfg, ns)

# Generate SSL context
context = curio.ssl.SSLContext(curio.ssl.PROTOCOL_TLS_SERVER)

@@ -105,15 +109,18 @@ async def tunnel_proc_async(pipe: 'Connection', cfg: config.ConfigurationType) -
context.load_dh_params(cfg.ssl_dhparam)

while True:
sock, address = await curio.run_in_thread(get_socket, pipe)
if not sock:
break
logger.debug(
f'{multiprocessing.current_process().pid!r}: Got new connection from {address!r}'
)
sock = await context.wrap_socket(curio.io.Socket(sock), server_side=True)
await group.spawn(tunneler, sock, address)
del sock
try:
sock, address = await curio.run_in_thread(get_socket, pipe)
if not sock:
break
logger.debug(
f'{multiprocessing.current_process().pid!r}: Got new connection from {address!r}'
)
sock = await context.wrap_socket(curio.io.Socket(sock), server_side=True)
await group.spawn(tunneler, sock, address)
del sock
except Exception as e:
logger.error('SETING UP CONNECTION: %s', e)

async with curio.TaskGroup() as tg:
await tg.spawn(run_server, pipe, cfg, tg)

@@ -133,10 +140,13 @@ def tunnel_main():
typing.Tuple['Connection', multiprocessing.Process, psutil.Process]
] = []

stats_collector = stats.GlobalStats()

for i in range(cfg.workers):
own_conn, child_conn = multiprocessing.Pipe()
task = multiprocessing.Process(
target=curio.run, args=(tunnel_proc_async, child_conn, cfg)
target=curio.run,
args=(tunnel_proc_async, child_conn, cfg, stats_collector.ns),
)
task.start()
child.append((own_conn, task, psutil.Process(task.pid)))
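The hunk above is where the shared state crosses process boundaries: the parent builds one GlobalStats (and therefore one Manager-backed Namespace) and hands stats_collector.ns to every worker at spawn time. A generic, self-contained illustration of that pattern follows; the worker/ns names and counter values are placeholders, not repository code. Note that '+=' on a Namespace attribute is a read-modify-write through the manager proxy, so concurrent increments from several processes are not atomic.

import multiprocessing

def worker(ns, n_bytes: int) -> None:
    # Attribute access goes through the manager proxy, so updates made here
    # are visible to the parent and to sibling workers.
    ns.total += 1
    ns.sent += n_bytes

if __name__ == '__main__':
    manager = multiprocessing.Manager()
    ns = manager.Namespace()
    ns.total = 0
    ns.sent = 0

    procs = [multiprocessing.Process(target=worker, args=(ns, 1024)) for _ in range(4)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()

    print(ns.total, ns.sent)   # typically "4 4096", subject to the atomicity caveat above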
@@ -166,10 +176,13 @@ def tunnel_main():
client, addr = sock.accept()
# Select BEST process for sending this new connection
best_child().send(message.Message(message.Command.TUNNEL, (client, addr)))

del client # Ensure socket is controlled on child process
except Exception:
logger.exception('Mar')
pass

logger.info('Exiting tunnel server')

if sock:
sock.close()

@@ -196,7 +209,12 @@ def main() -> None:

if args.tunnel:
tunnel_main()
parser.print_help()
elif args.detailed_stats:
curio.run(stats.getServerStats, True)
elif args.stats:
curio.run(stats.getServerStats, False)
else:
parser.print_help()


if __name__ == "__main__":
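One small detail in the main() dispatch above worth spelling out: curio.run() accepts the coroutine function followed by its positional arguments, so curio.run(stats.getServerStats, True) runs getServerStats(True) inside the curio kernel. A tiny sketch of that call style; the example coroutine is not from the repository.

import curio

async def example(detailed: bool) -> None:
    # Stands in for stats.getServerStats; curio passes the extra run() args through.
    print('detailed stats' if detailed else 'summary stats')

curio.run(example, True)    # prints "detailed stats"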