forked from shaba/openuds
Advancing on new tunneler
This commit is contained in:
parent
e486d6708d
commit
971e5984d9
@ -143,6 +143,7 @@ class Handler(socketserver.BaseRequestHandler):
|
|||||||
if not data:
|
if not data:
|
||||||
break
|
break
|
||||||
self.request.sendall(data)
|
self.request.sendall(data)
|
||||||
|
logger.debug('Finished process')
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -173,8 +174,8 @@ def forward(
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
fs1 = forward(('fake.udsenterprise.com', 7777), '0'*64, local_port=49998)
|
fs1 = forward(('fake.udsenterprise.com', 7777), '0'*64, local_port=49998)
|
||||||
print(f'Listening on {fs1.server_address}')
|
print(f'Listening on {fs1.server_address}')
|
||||||
#fs2 = forward(('fake.udsenterprise.com', 7777), '1'*64, local_port=49999)
|
fs2 = forward(('fake.udsenterprise.com', 7777), '1'*64, local_port=49999)
|
||||||
#print(f'Listening on {fs2.server_address}')
|
print(f'Listening on {fs2.server_address}')
|
||||||
# time.sleep(30)
|
# time.sleep(30)
|
||||||
# fs.stop()
|
# fs.stop()
|
||||||
|
|
||||||
|
@ -38,15 +38,18 @@ from . import config
|
|||||||
from . import stats
|
from . import stats
|
||||||
from . import consts
|
from . import consts
|
||||||
|
|
||||||
|
if typing.TYPE_CHECKING:
|
||||||
|
from multiprocessing.managers import Namespace
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class Proxy:
|
class Proxy:
|
||||||
cfg: config.ConfigurationType
|
cfg: config.ConfigurationType
|
||||||
stat: stats.Stats
|
ns: 'Namespace'
|
||||||
|
|
||||||
def __init__(self, cfg: config.ConfigurationType) -> None:
|
def __init__(self, cfg: config.ConfigurationType, ns: 'Namespace') -> None:
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
self.stat = stats.Stats()
|
self.ns = ns
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def getFromUds(cfg: config.ConfigurationType, ticket: bytes) -> typing.MutableMapping[str, typing.Any]:
|
def getFromUds(cfg: config.ConfigurationType, ticket: bytes) -> typing.MutableMapping[str, typing.Any]:
|
||||||
@ -77,6 +80,10 @@ class Proxy:
|
|||||||
await destination.sendall(data)
|
await destination.sendall(data)
|
||||||
counter.add(len(data))
|
counter.add(len(data))
|
||||||
|
|
||||||
|
# Method responsible of proxying requests
|
||||||
|
async def __call__(self, source, address: typing.Tuple[str, int]) -> None:
|
||||||
|
await self.proxy(source, address)
|
||||||
|
|
||||||
async def stats(self, full: bool, source, address: typing.Tuple[str, int]) -> None:
|
async def stats(self, full: bool, source, address: typing.Tuple[str, int]) -> None:
|
||||||
# Check valid source ip
|
# Check valid source ip
|
||||||
if address[0] not in self.cfg.allow:
|
if address[0] not in self.cfg.allow:
|
||||||
@ -93,17 +100,12 @@ class Proxy:
|
|||||||
|
|
||||||
logger.info('STATS TO %s', address)
|
logger.info('STATS TO %s', address)
|
||||||
|
|
||||||
if full:
|
data = stats.GlobalStats.get_stats(self.ns)
|
||||||
data = self.stat.full_as_csv()
|
|
||||||
else:
|
|
||||||
data = self.stat.simple_as_csv()
|
|
||||||
|
|
||||||
async for v in data:
|
for v in data:
|
||||||
|
logger.debug('SENDING %s', v)
|
||||||
await source.sendall(v.encode() + b'\n')
|
await source.sendall(v.encode() + b'\n')
|
||||||
|
|
||||||
# Method responsible of proxying requests
|
|
||||||
async def __call__(self, source, address: typing.Tuple[str, int]) -> None:
|
|
||||||
await self.proxy(source, address)
|
|
||||||
|
|
||||||
async def proxy(self, source, address: typing.Tuple[str, int]) -> None:
|
async def proxy(self, source, address: typing.Tuple[str, int]) -> None:
|
||||||
logger.info('OPEN FROM %s', address)
|
logger.info('OPEN FROM %s', address)
|
||||||
@ -119,7 +121,6 @@ class Proxy:
|
|||||||
logger.exception('HANDSHAKE')
|
logger.exception('HANDSHAKE')
|
||||||
logger.error('HANDSHAKE from %s', address)
|
logger.error('HANDSHAKE from %s', address)
|
||||||
await source.sendall(b'HANDSHAKE_ERROR')
|
await source.sendall(b'HANDSHAKE_ERROR')
|
||||||
|
|
||||||
# Closes connection now
|
# Closes connection now
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -166,7 +167,7 @@ class Proxy:
|
|||||||
await source.sendall(b'OK')
|
await source.sendall(b'OK')
|
||||||
|
|
||||||
# Initialize own stats counter
|
# Initialize own stats counter
|
||||||
counter = await self.stat.new()
|
counter = stats.Stats(self.ns)
|
||||||
|
|
||||||
# Open remote server connection
|
# Open remote server connection
|
||||||
try:
|
try:
|
||||||
@ -184,8 +185,7 @@ class Proxy:
|
|||||||
|
|
||||||
logger.error('REMOTE from %s: %s', address, e)
|
logger.error('REMOTE from %s: %s', address, e)
|
||||||
finally:
|
finally:
|
||||||
await counter.close()
|
counter.close()
|
||||||
|
|
||||||
|
|
||||||
logger.info('CLOSED FROM %s', address)
|
logger.info('CLOSED FROM %s', address)
|
||||||
logger.info('STATS: %s', counter.as_csv())
|
|
||||||
|
@ -28,30 +28,30 @@
|
|||||||
'''
|
'''
|
||||||
@author: Adolfo Gómez, dkmaster at dkmon dot com
|
@author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||||
'''
|
'''
|
||||||
|
import multiprocessing
|
||||||
import time
|
import time
|
||||||
|
import logging
|
||||||
|
import typing
|
||||||
import io
|
import io
|
||||||
import ssl
|
import ssl
|
||||||
import logging
|
import logging
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
import curio
|
import curio
|
||||||
import blist
|
|
||||||
|
|
||||||
from . import config
|
from . import config
|
||||||
from . import consts
|
from . import consts
|
||||||
|
|
||||||
|
|
||||||
|
if typing.TYPE_CHECKING:
|
||||||
|
from multiprocessing.managers import Namespace, SyncManager
|
||||||
|
|
||||||
|
INTERVAL = 2 # Interval in seconds between stats update
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Locker for id assigner
|
|
||||||
assignLock = curio.Lock()
|
|
||||||
|
|
||||||
# Tuple index for several stats
|
|
||||||
SENT, RECV = 0, 1
|
|
||||||
|
|
||||||
# Subclasses for += operation to work
|
|
||||||
class StatsSingleCounter:
|
class StatsSingleCounter:
|
||||||
def __init__(self, parent: 'StatsConnection', for_receiving=True) -> None:
|
def __init__(self, parent: 'Stats', for_receiving=True) -> None:
|
||||||
if for_receiving:
|
if for_receiving:
|
||||||
self.adder = parent.add_recv
|
self.adder = parent.add_recv
|
||||||
else:
|
else:
|
||||||
@ -62,56 +62,34 @@ class StatsSingleCounter:
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
class StatsConnection:
|
class Stats:
|
||||||
id: int
|
ns: 'Namespace'
|
||||||
recv: int
|
|
||||||
sent: int
|
sent: int
|
||||||
start_time: int
|
recv: int
|
||||||
parent: 'Stats'
|
last: float
|
||||||
|
|
||||||
# Bandwidth stats (SENT, RECV)
|
def __init__(self, ns: 'Namespace'):
|
||||||
last: typing.List[int]
|
self.ns = ns
|
||||||
last_time: typing.List[float]
|
self.ns.current += 1
|
||||||
|
self.ns.total += 1
|
||||||
|
self.sent = 0
|
||||||
|
self.recv = 0
|
||||||
|
self.last = time.monotonic()
|
||||||
|
|
||||||
bandwidth: typing.List[int]
|
def update(self, force: bool = False):
|
||||||
max_bandwidth: typing.List[int]
|
now = time.monotonic()
|
||||||
|
if force or now - self.last > INTERVAL:
|
||||||
def __init__(self, parent: 'Stats', id: int) -> None:
|
self.last = now
|
||||||
self.id = id
|
self.ns.recv = self.recv
|
||||||
self.recv = self.sent = 0
|
self.ns.sent = self.sent
|
||||||
|
|
||||||
now = time.time()
|
|
||||||
self.start_time = int(now)
|
|
||||||
self.parent = parent
|
|
||||||
|
|
||||||
self.last = [0, 0]
|
|
||||||
self.last_time = [now, now]
|
|
||||||
self.bandwidth = [0, 0]
|
|
||||||
self.max_bandwidth = [0, 0]
|
|
||||||
|
|
||||||
def update_bandwidth(self, kind: int, counter: int):
|
|
||||||
now = time.time()
|
|
||||||
elapsed = now - self.last_time[kind]
|
|
||||||
# Update only when enouth data
|
|
||||||
if elapsed < consts.BANDWIDTH_TIME:
|
|
||||||
return
|
|
||||||
total = counter - self.last[kind]
|
|
||||||
self.bandwidth[kind] = int(float(total) / elapsed)
|
|
||||||
self.last[kind] = counter
|
|
||||||
self.last_time[kind] = now
|
|
||||||
|
|
||||||
if self.bandwidth[kind] > self.max_bandwidth[kind]:
|
|
||||||
self.max_bandwidth[kind] = self.bandwidth[kind]
|
|
||||||
|
|
||||||
def add_recv(self, size: int) -> None:
|
def add_recv(self, size: int) -> None:
|
||||||
self.recv += size
|
self.recv += size
|
||||||
self.update_bandwidth(RECV, counter=self.recv)
|
self.update()
|
||||||
self.parent.add_recv(size)
|
|
||||||
|
|
||||||
def add_sent(self, size: int) -> None:
|
def add_sent(self, size: int) -> None:
|
||||||
self.sent += size
|
self.sent += size
|
||||||
self.update_bandwidth(SENT, counter=self.sent)
|
self.update()
|
||||||
self.parent.add_sent(size)
|
|
||||||
|
|
||||||
def as_sent_counter(self) -> 'StatsSingleCounter':
|
def as_sent_counter(self) -> 'StatsSingleCounter':
|
||||||
return StatsSingleCounter(self, False)
|
return StatsSingleCounter(self, False)
|
||||||
@ -119,114 +97,34 @@ class StatsConnection:
|
|||||||
def as_recv_counter(self) -> 'StatsSingleCounter':
|
def as_recv_counter(self) -> 'StatsSingleCounter':
|
||||||
return StatsSingleCounter(self, True)
|
return StatsSingleCounter(self, True)
|
||||||
|
|
||||||
async def close(self) -> None:
|
def close(self):
|
||||||
if self.id:
|
self.update(True)
|
||||||
logger.debug(f'STAT {self.id} closed')
|
self.ns.current -= 1
|
||||||
await self.parent.remove(self.id)
|
|
||||||
self.id = 0
|
|
||||||
|
|
||||||
def as_csv(self, separator: typing.Optional[str] = None) -> str:
|
# Stats collector thread
|
||||||
separator = separator or ';'
|
class GlobalStats:
|
||||||
# With connections of less than a second, consider them as a second
|
manager: 'SyncManager'
|
||||||
elapsed = (int(time.time()) - self.start_time)
|
ns: 'Namespace'
|
||||||
|
counter: int
|
||||||
|
|
||||||
return separator.join(
|
def __init__(self):
|
||||||
str(i)
|
super().__init__()
|
||||||
for i in (
|
self.manager = multiprocessing.Manager()
|
||||||
self.id,
|
self.ns = self.manager.Namespace()
|
||||||
self.start_time,
|
|
||||||
elapsed,
|
|
||||||
self.sent,
|
|
||||||
self.bandwidth[SENT],
|
|
||||||
self.max_bandwidth[SENT],
|
|
||||||
self.recv,
|
|
||||||
self.bandwidth[RECV],
|
|
||||||
self.max_bandwidth[RECV],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
# Counters
|
||||||
return f'{self.id} t:{int(time.time())-self.start_time}, r:{self.recv}, s:{self.sent}>'
|
self.ns.current = 0
|
||||||
|
self.ns.total = 0
|
||||||
|
self.ns.sent = 0
|
||||||
|
self.ns.recv = 0
|
||||||
|
self.counter = 0
|
||||||
|
|
||||||
# For sorted array
|
def info(self) -> typing.Iterable[str]:
|
||||||
def __lt__(self, other) -> bool:
|
return GlobalStats.get_stats(self.ns)
|
||||||
if isinstance(other, int):
|
|
||||||
return self.id < other
|
|
||||||
|
|
||||||
if not isinstance(other, StatsConnection):
|
|
||||||
raise NotImplemented
|
|
||||||
|
|
||||||
return self.id < other.id
|
|
||||||
|
|
||||||
def __eq__(self, other) -> bool:
|
|
||||||
if isinstance(other, int):
|
|
||||||
return self.id == other
|
|
||||||
|
|
||||||
if not isinstance(other, StatsConnection):
|
|
||||||
raise NotImplemented
|
|
||||||
|
|
||||||
return self.id == other.id
|
|
||||||
|
|
||||||
|
|
||||||
class Stats:
|
|
||||||
counter_id: int
|
|
||||||
|
|
||||||
total_sent: int
|
|
||||||
total_received: int
|
|
||||||
current_connections: blist.sortedlist
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
# First connection will be 1
|
|
||||||
self.counter_id = 0
|
|
||||||
self.total_sent = self.total_received = 0
|
|
||||||
self.current_connections = blist.sortedlist()
|
|
||||||
|
|
||||||
async def new(self) -> StatsConnection:
|
|
||||||
"""Initializes a connection stats counter and returns it id
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: connection id
|
|
||||||
"""
|
|
||||||
async with assignLock:
|
|
||||||
self.counter_id += 1
|
|
||||||
connection = StatsConnection(self, self.counter_id)
|
|
||||||
self.current_connections.add(connection)
|
|
||||||
return connection
|
|
||||||
|
|
||||||
def add_sent(self, size: int) -> None:
|
|
||||||
self.total_sent += size
|
|
||||||
|
|
||||||
def add_recv(self, size: int) -> None:
|
|
||||||
self.total_received += size
|
|
||||||
|
|
||||||
async def remove(self, connection_id: int) -> None:
|
|
||||||
async with assignLock:
|
|
||||||
try:
|
|
||||||
self.current_connections.remove(connection_id)
|
|
||||||
except Exception:
|
|
||||||
logger.debug(
|
|
||||||
'Tried to remove %s from connections but was not present',
|
|
||||||
connection_id,
|
|
||||||
)
|
|
||||||
# Does not exists, ignore it
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def simple_as_csv(self, separator: typing.Optional[str] = None) -> typing.AsyncIterable[str]:
|
|
||||||
separator = separator or ';'
|
|
||||||
yield separator.join(
|
|
||||||
str(i)
|
|
||||||
for i in (
|
|
||||||
self.counter_id,
|
|
||||||
self.total_sent,
|
|
||||||
self.total_received,
|
|
||||||
len(self.current_connections),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def full_as_csv(self, separator: typing.Optional[str] = None) -> typing.AsyncIterable[str]:
|
|
||||||
for i in self.current_connections:
|
|
||||||
yield i.as_csv(separator)
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_stats(ns: 'Namespace') -> typing.Iterable[str]:
|
||||||
|
yield ';'.join([str(ns.current), str(ns.total), str(ns.sent), str(ns.recv)])
|
||||||
|
|
||||||
# Stats processor, invoked from command line
|
# Stats processor, invoked from command line
|
||||||
async def getServerStats(detailed: bool = False) -> None:
|
async def getServerStats(detailed: bool = False) -> None:
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
#
|
#
|
||||||
# Copyright (c) 2020 Virtual Cable S.L.U.
|
# Copyright (c) 2021 Virtual Cable S.L.U.
|
||||||
# All rights reserved.
|
# All rights reserved.
|
||||||
#
|
#
|
||||||
# Redistribution and use in source and binary forms, with or without modification,
|
# Redistribution and use in source and binary forms, with or without modification,
|
||||||
@ -32,6 +32,7 @@
|
|||||||
import sys
|
import sys
|
||||||
import argparse
|
import argparse
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
|
import threading
|
||||||
import socket
|
import socket
|
||||||
import logging
|
import logging
|
||||||
import typing
|
import typing
|
||||||
@ -47,6 +48,7 @@ from uds_tunnel import stats
|
|||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
if typing.TYPE_CHECKING:
|
||||||
from multiprocessing.connection import Connection
|
from multiprocessing.connection import Connection
|
||||||
|
from multiprocessing.managers import Namespace
|
||||||
|
|
||||||
BACKLOG = 100
|
BACKLOG = 100
|
||||||
|
|
||||||
@ -76,7 +78,9 @@ def setup_log(cfg: config.ConfigurationType) -> None:
|
|||||||
log.addHandler(fileh)
|
log.addHandler(fileh)
|
||||||
|
|
||||||
|
|
||||||
async def tunnel_proc_async(pipe: 'Connection', cfg: config.ConfigurationType) -> None:
|
async def tunnel_proc_async(
|
||||||
|
pipe: 'Connection', cfg: config.ConfigurationType, ns: 'Namespace'
|
||||||
|
) -> None:
|
||||||
def get_socket(pipe: 'Connection') -> typing.Tuple[socket.SocketType, typing.Any]:
|
def get_socket(pipe: 'Connection') -> typing.Tuple[socket.SocketType, typing.Any]:
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
@ -92,7 +96,7 @@ async def tunnel_proc_async(pipe: 'Connection', cfg: config.ConfigurationType) -
|
|||||||
pipe: 'Connection', cfg: config.ConfigurationType, group: curio.TaskGroup
|
pipe: 'Connection', cfg: config.ConfigurationType, group: curio.TaskGroup
|
||||||
) -> None:
|
) -> None:
|
||||||
# Instantiate a proxy redirector for this process (we only need one per process!!)
|
# Instantiate a proxy redirector for this process (we only need one per process!!)
|
||||||
tunneler = proxy.Proxy(cfg)
|
tunneler = proxy.Proxy(cfg, ns)
|
||||||
|
|
||||||
# Generate SSL context
|
# Generate SSL context
|
||||||
context = curio.ssl.SSLContext(curio.ssl.PROTOCOL_TLS_SERVER)
|
context = curio.ssl.SSLContext(curio.ssl.PROTOCOL_TLS_SERVER)
|
||||||
@ -105,15 +109,18 @@ async def tunnel_proc_async(pipe: 'Connection', cfg: config.ConfigurationType) -
|
|||||||
context.load_dh_params(cfg.ssl_dhparam)
|
context.load_dh_params(cfg.ssl_dhparam)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
sock, address = await curio.run_in_thread(get_socket, pipe)
|
try:
|
||||||
if not sock:
|
sock, address = await curio.run_in_thread(get_socket, pipe)
|
||||||
break
|
if not sock:
|
||||||
logger.debug(
|
break
|
||||||
f'{multiprocessing.current_process().pid!r}: Got new connection from {address!r}'
|
logger.debug(
|
||||||
)
|
f'{multiprocessing.current_process().pid!r}: Got new connection from {address!r}'
|
||||||
sock = await context.wrap_socket(curio.io.Socket(sock), server_side=True)
|
)
|
||||||
await group.spawn(tunneler, sock, address)
|
sock = await context.wrap_socket(curio.io.Socket(sock), server_side=True)
|
||||||
del sock
|
await group.spawn(tunneler, sock, address)
|
||||||
|
del sock
|
||||||
|
except Exception as e:
|
||||||
|
logger.error('SETING UP CONNECTION: %s', e)
|
||||||
|
|
||||||
async with curio.TaskGroup() as tg:
|
async with curio.TaskGroup() as tg:
|
||||||
await tg.spawn(run_server, pipe, cfg, tg)
|
await tg.spawn(run_server, pipe, cfg, tg)
|
||||||
@ -133,10 +140,13 @@ def tunnel_main():
|
|||||||
typing.Tuple['Connection', multiprocessing.Process, psutil.Process]
|
typing.Tuple['Connection', multiprocessing.Process, psutil.Process]
|
||||||
] = []
|
] = []
|
||||||
|
|
||||||
|
stats_collector = stats.GlobalStats()
|
||||||
|
|
||||||
for i in range(cfg.workers):
|
for i in range(cfg.workers):
|
||||||
own_conn, child_conn = multiprocessing.Pipe()
|
own_conn, child_conn = multiprocessing.Pipe()
|
||||||
task = multiprocessing.Process(
|
task = multiprocessing.Process(
|
||||||
target=curio.run, args=(tunnel_proc_async, child_conn, cfg)
|
target=curio.run,
|
||||||
|
args=(tunnel_proc_async, child_conn, cfg, stats_collector.ns),
|
||||||
)
|
)
|
||||||
task.start()
|
task.start()
|
||||||
child.append((own_conn, task, psutil.Process(task.pid)))
|
child.append((own_conn, task, psutil.Process(task.pid)))
|
||||||
@ -166,10 +176,13 @@ def tunnel_main():
|
|||||||
client, addr = sock.accept()
|
client, addr = sock.accept()
|
||||||
# Select BEST process for sending this new connection
|
# Select BEST process for sending this new connection
|
||||||
best_child().send(message.Message(message.Command.TUNNEL, (client, addr)))
|
best_child().send(message.Message(message.Command.TUNNEL, (client, addr)))
|
||||||
|
del client # Ensure socket is controlled on child process
|
||||||
except Exception:
|
except Exception:
|
||||||
|
logger.exception('Mar')
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
logger.info('Exiting tunnel server')
|
||||||
|
|
||||||
if sock:
|
if sock:
|
||||||
sock.close()
|
sock.close()
|
||||||
|
|
||||||
@ -196,7 +209,12 @@ def main() -> None:
|
|||||||
|
|
||||||
if args.tunnel:
|
if args.tunnel:
|
||||||
tunnel_main()
|
tunnel_main()
|
||||||
parser.print_help()
|
elif args.detailed_stats:
|
||||||
|
curio.run(stats.getServerStats, True)
|
||||||
|
elif args.stats:
|
||||||
|
curio.run(stats.getServerStats, False)
|
||||||
|
else:
|
||||||
|
parser.print_help()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Loading…
Reference in New Issue
Block a user