1
0
mirror of https://github.com/dkmstr/openuds.git synced 2025-01-11 05:17:55 +03:00

Advancing on new tunneler

This commit is contained in:
Adolfo Gómez García 2021-01-13 10:04:26 +01:00
parent e486d6708d
commit 971e5984d9
4 changed files with 102 additions and 185 deletions

View File

@ -143,6 +143,7 @@ class Handler(socketserver.BaseRequestHandler):
if not data:
break
self.request.sendall(data)
logger.debug('Finished process')
except Exception as e:
pass
@ -173,8 +174,8 @@ def forward(
if __name__ == "__main__":
fs1 = forward(('fake.udsenterprise.com', 7777), '0'*64, local_port=49998)
print(f'Listening on {fs1.server_address}')
#fs2 = forward(('fake.udsenterprise.com', 7777), '1'*64, local_port=49999)
#print(f'Listening on {fs2.server_address}')
fs2 = forward(('fake.udsenterprise.com', 7777), '1'*64, local_port=49999)
print(f'Listening on {fs2.server_address}')
# time.sleep(30)
# fs.stop()

View File

@ -38,15 +38,18 @@ from . import config
from . import stats
from . import consts
if typing.TYPE_CHECKING:
from multiprocessing.managers import Namespace
logger = logging.getLogger(__name__)
class Proxy:
cfg: config.ConfigurationType
stat: stats.Stats
ns: 'Namespace'
def __init__(self, cfg: config.ConfigurationType) -> None:
def __init__(self, cfg: config.ConfigurationType, ns: 'Namespace') -> None:
self.cfg = cfg
self.stat = stats.Stats()
self.ns = ns
@staticmethod
def getFromUds(cfg: config.ConfigurationType, ticket: bytes) -> typing.MutableMapping[str, typing.Any]:
@ -77,6 +80,10 @@ class Proxy:
await destination.sendall(data)
counter.add(len(data))
# Method responsible of proxying requests
async def __call__(self, source, address: typing.Tuple[str, int]) -> None:
await self.proxy(source, address)
async def stats(self, full: bool, source, address: typing.Tuple[str, int]) -> None:
# Check valid source ip
if address[0] not in self.cfg.allow:
@ -93,17 +100,12 @@ class Proxy:
logger.info('STATS TO %s', address)
if full:
data = self.stat.full_as_csv()
else:
data = self.stat.simple_as_csv()
data = stats.GlobalStats.get_stats(self.ns)
async for v in data:
for v in data:
logger.debug('SENDING %s', v)
await source.sendall(v.encode() + b'\n')
# Method responsible of proxying requests
async def __call__(self, source, address: typing.Tuple[str, int]) -> None:
await self.proxy(source, address)
async def proxy(self, source, address: typing.Tuple[str, int]) -> None:
logger.info('OPEN FROM %s', address)
@ -119,7 +121,6 @@ class Proxy:
logger.exception('HANDSHAKE')
logger.error('HANDSHAKE from %s', address)
await source.sendall(b'HANDSHAKE_ERROR')
# Closes connection now
return
@ -166,7 +167,7 @@ class Proxy:
await source.sendall(b'OK')
# Initialize own stats counter
counter = await self.stat.new()
counter = stats.Stats(self.ns)
# Open remote server connection
try:
@ -184,8 +185,7 @@ class Proxy:
logger.error('REMOTE from %s: %s', address, e)
finally:
await counter.close()
counter.close()
logger.info('CLOSED FROM %s', address)
logger.info('STATS: %s', counter.as_csv())

View File

@ -28,30 +28,30 @@
'''
@author: Adolfo Gómez, dkmaster at dkmon dot com
'''
import multiprocessing
import time
import logging
import typing
import io
import ssl
import logging
import typing
import curio
import blist
from . import config
from . import consts
if typing.TYPE_CHECKING:
from multiprocessing.managers import Namespace, SyncManager
INTERVAL = 2 # Interval in seconds between stats update
logger = logging.getLogger(__name__)
# Locker for id assigner
assignLock = curio.Lock()
# Tuple index for several stats
SENT, RECV = 0, 1
# Subclasses for += operation to work
class StatsSingleCounter:
def __init__(self, parent: 'StatsConnection', for_receiving=True) -> None:
def __init__(self, parent: 'Stats', for_receiving=True) -> None:
if for_receiving:
self.adder = parent.add_recv
else:
@ -62,56 +62,34 @@ class StatsSingleCounter:
return self
class StatsConnection:
id: int
recv: int
class Stats:
ns: 'Namespace'
sent: int
start_time: int
parent: 'Stats'
recv: int
last: float
# Bandwidth stats (SENT, RECV)
last: typing.List[int]
last_time: typing.List[float]
def __init__(self, ns: 'Namespace'):
self.ns = ns
self.ns.current += 1
self.ns.total += 1
self.sent = 0
self.recv = 0
self.last = time.monotonic()
bandwidth: typing.List[int]
max_bandwidth: typing.List[int]
def __init__(self, parent: 'Stats', id: int) -> None:
self.id = id
self.recv = self.sent = 0
now = time.time()
self.start_time = int(now)
self.parent = parent
self.last = [0, 0]
self.last_time = [now, now]
self.bandwidth = [0, 0]
self.max_bandwidth = [0, 0]
def update_bandwidth(self, kind: int, counter: int):
now = time.time()
elapsed = now - self.last_time[kind]
# Update only when enouth data
if elapsed < consts.BANDWIDTH_TIME:
return
total = counter - self.last[kind]
self.bandwidth[kind] = int(float(total) / elapsed)
self.last[kind] = counter
self.last_time[kind] = now
if self.bandwidth[kind] > self.max_bandwidth[kind]:
self.max_bandwidth[kind] = self.bandwidth[kind]
def update(self, force: bool = False):
now = time.monotonic()
if force or now - self.last > INTERVAL:
self.last = now
self.ns.recv = self.recv
self.ns.sent = self.sent
def add_recv(self, size: int) -> None:
self.recv += size
self.update_bandwidth(RECV, counter=self.recv)
self.parent.add_recv(size)
self.update()
def add_sent(self, size: int) -> None:
self.sent += size
self.update_bandwidth(SENT, counter=self.sent)
self.parent.add_sent(size)
self.update()
def as_sent_counter(self) -> 'StatsSingleCounter':
return StatsSingleCounter(self, False)
@ -119,114 +97,34 @@ class StatsConnection:
def as_recv_counter(self) -> 'StatsSingleCounter':
return StatsSingleCounter(self, True)
async def close(self) -> None:
if self.id:
logger.debug(f'STAT {self.id} closed')
await self.parent.remove(self.id)
self.id = 0
def close(self):
self.update(True)
self.ns.current -= 1
def as_csv(self, separator: typing.Optional[str] = None) -> str:
separator = separator or ';'
# With connections of less than a second, consider them as a second
elapsed = (int(time.time()) - self.start_time)
# Stats collector thread
class GlobalStats:
manager: 'SyncManager'
ns: 'Namespace'
counter: int
return separator.join(
str(i)
for i in (
self.id,
self.start_time,
elapsed,
self.sent,
self.bandwidth[SENT],
self.max_bandwidth[SENT],
self.recv,
self.bandwidth[RECV],
self.max_bandwidth[RECV],
)
)
def __init__(self):
super().__init__()
self.manager = multiprocessing.Manager()
self.ns = self.manager.Namespace()
def __str__(self) -> str:
return f'{self.id} t:{int(time.time())-self.start_time}, r:{self.recv}, s:{self.sent}>'
# Counters
self.ns.current = 0
self.ns.total = 0
self.ns.sent = 0
self.ns.recv = 0
self.counter = 0
# For sorted array
def __lt__(self, other) -> bool:
if isinstance(other, int):
return self.id < other
if not isinstance(other, StatsConnection):
raise NotImplemented
return self.id < other.id
def __eq__(self, other) -> bool:
if isinstance(other, int):
return self.id == other
if not isinstance(other, StatsConnection):
raise NotImplemented
return self.id == other.id
class Stats:
counter_id: int
total_sent: int
total_received: int
current_connections: blist.sortedlist
def __init__(self) -> None:
# First connection will be 1
self.counter_id = 0
self.total_sent = self.total_received = 0
self.current_connections = blist.sortedlist()
async def new(self) -> StatsConnection:
"""Initializes a connection stats counter and returns it id
Returns:
str: connection id
"""
async with assignLock:
self.counter_id += 1
connection = StatsConnection(self, self.counter_id)
self.current_connections.add(connection)
return connection
def add_sent(self, size: int) -> None:
self.total_sent += size
def add_recv(self, size: int) -> None:
self.total_received += size
async def remove(self, connection_id: int) -> None:
async with assignLock:
try:
self.current_connections.remove(connection_id)
except Exception:
logger.debug(
'Tried to remove %s from connections but was not present',
connection_id,
)
# Does not exists, ignore it
pass
async def simple_as_csv(self, separator: typing.Optional[str] = None) -> typing.AsyncIterable[str]:
separator = separator or ';'
yield separator.join(
str(i)
for i in (
self.counter_id,
self.total_sent,
self.total_received,
len(self.current_connections),
)
)
async def full_as_csv(self, separator: typing.Optional[str] = None) -> typing.AsyncIterable[str]:
for i in self.current_connections:
yield i.as_csv(separator)
def info(self) -> typing.Iterable[str]:
return GlobalStats.get_stats(self.ns)
@staticmethod
def get_stats(ns: 'Namespace') -> typing.Iterable[str]:
yield ';'.join([str(ns.current), str(ns.total), str(ns.sent), str(ns.recv)])
# Stats processor, invoked from command line
async def getServerStats(detailed: bool = False) -> None:

View File

@ -1,7 +1,7 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (c) 2020 Virtual Cable S.L.U.
# Copyright (c) 2021 Virtual Cable S.L.U.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
@ -32,6 +32,7 @@
import sys
import argparse
import multiprocessing
import threading
import socket
import logging
import typing
@ -47,6 +48,7 @@ from uds_tunnel import stats
if typing.TYPE_CHECKING:
from multiprocessing.connection import Connection
from multiprocessing.managers import Namespace
BACKLOG = 100
@ -76,7 +78,9 @@ def setup_log(cfg: config.ConfigurationType) -> None:
log.addHandler(fileh)
async def tunnel_proc_async(pipe: 'Connection', cfg: config.ConfigurationType) -> None:
async def tunnel_proc_async(
pipe: 'Connection', cfg: config.ConfigurationType, ns: 'Namespace'
) -> None:
def get_socket(pipe: 'Connection') -> typing.Tuple[socket.SocketType, typing.Any]:
try:
while True:
@ -92,7 +96,7 @@ async def tunnel_proc_async(pipe: 'Connection', cfg: config.ConfigurationType) -
pipe: 'Connection', cfg: config.ConfigurationType, group: curio.TaskGroup
) -> None:
# Instantiate a proxy redirector for this process (we only need one per process!!)
tunneler = proxy.Proxy(cfg)
tunneler = proxy.Proxy(cfg, ns)
# Generate SSL context
context = curio.ssl.SSLContext(curio.ssl.PROTOCOL_TLS_SERVER)
@ -105,15 +109,18 @@ async def tunnel_proc_async(pipe: 'Connection', cfg: config.ConfigurationType) -
context.load_dh_params(cfg.ssl_dhparam)
while True:
sock, address = await curio.run_in_thread(get_socket, pipe)
if not sock:
break
logger.debug(
f'{multiprocessing.current_process().pid!r}: Got new connection from {address!r}'
)
sock = await context.wrap_socket(curio.io.Socket(sock), server_side=True)
await group.spawn(tunneler, sock, address)
del sock
try:
sock, address = await curio.run_in_thread(get_socket, pipe)
if not sock:
break
logger.debug(
f'{multiprocessing.current_process().pid!r}: Got new connection from {address!r}'
)
sock = await context.wrap_socket(curio.io.Socket(sock), server_side=True)
await group.spawn(tunneler, sock, address)
del sock
except Exception as e:
logger.error('SETING UP CONNECTION: %s', e)
async with curio.TaskGroup() as tg:
await tg.spawn(run_server, pipe, cfg, tg)
@ -133,10 +140,13 @@ def tunnel_main():
typing.Tuple['Connection', multiprocessing.Process, psutil.Process]
] = []
stats_collector = stats.GlobalStats()
for i in range(cfg.workers):
own_conn, child_conn = multiprocessing.Pipe()
task = multiprocessing.Process(
target=curio.run, args=(tunnel_proc_async, child_conn, cfg)
target=curio.run,
args=(tunnel_proc_async, child_conn, cfg, stats_collector.ns),
)
task.start()
child.append((own_conn, task, psutil.Process(task.pid)))
@ -166,10 +176,13 @@ def tunnel_main():
client, addr = sock.accept()
# Select BEST process for sending this new connection
best_child().send(message.Message(message.Command.TUNNEL, (client, addr)))
del client # Ensure socket is controlled on child process
except Exception:
logger.exception('Mar')
pass
logger.info('Exiting tunnel server')
if sock:
sock.close()
@ -196,7 +209,12 @@ def main() -> None:
if args.tunnel:
tunnel_main()
parser.print_help()
elif args.detailed_stats:
curio.run(stats.getServerStats, True)
elif args.stats:
curio.run(stats.getServerStats, False)
else:
parser.print_help()
if __name__ == "__main__":