Advancing on new tunneler

This commit is contained in:
Adolfo Gómez García 2021-01-13 10:04:26 +01:00
parent e486d6708d
commit 971e5984d9
4 changed files with 102 additions and 185 deletions

View File

@ -143,6 +143,7 @@ class Handler(socketserver.BaseRequestHandler):
if not data: if not data:
break break
self.request.sendall(data) self.request.sendall(data)
logger.debug('Finished process')
except Exception as e: except Exception as e:
pass pass
@ -173,8 +174,8 @@ def forward(
if __name__ == "__main__": if __name__ == "__main__":
fs1 = forward(('fake.udsenterprise.com', 7777), '0'*64, local_port=49998) fs1 = forward(('fake.udsenterprise.com', 7777), '0'*64, local_port=49998)
print(f'Listening on {fs1.server_address}') print(f'Listening on {fs1.server_address}')
#fs2 = forward(('fake.udsenterprise.com', 7777), '1'*64, local_port=49999) fs2 = forward(('fake.udsenterprise.com', 7777), '1'*64, local_port=49999)
#print(f'Listening on {fs2.server_address}') print(f'Listening on {fs2.server_address}')
# time.sleep(30) # time.sleep(30)
# fs.stop() # fs.stop()

View File

@ -38,15 +38,18 @@ from . import config
from . import stats from . import stats
from . import consts from . import consts
if typing.TYPE_CHECKING:
from multiprocessing.managers import Namespace
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class Proxy: class Proxy:
cfg: config.ConfigurationType cfg: config.ConfigurationType
stat: stats.Stats ns: 'Namespace'
def __init__(self, cfg: config.ConfigurationType) -> None: def __init__(self, cfg: config.ConfigurationType, ns: 'Namespace') -> None:
self.cfg = cfg self.cfg = cfg
self.stat = stats.Stats() self.ns = ns
@staticmethod @staticmethod
def getFromUds(cfg: config.ConfigurationType, ticket: bytes) -> typing.MutableMapping[str, typing.Any]: def getFromUds(cfg: config.ConfigurationType, ticket: bytes) -> typing.MutableMapping[str, typing.Any]:
@ -77,6 +80,10 @@ class Proxy:
await destination.sendall(data) await destination.sendall(data)
counter.add(len(data)) counter.add(len(data))
# Method responsible of proxying requests
async def __call__(self, source, address: typing.Tuple[str, int]) -> None:
await self.proxy(source, address)
async def stats(self, full: bool, source, address: typing.Tuple[str, int]) -> None: async def stats(self, full: bool, source, address: typing.Tuple[str, int]) -> None:
# Check valid source ip # Check valid source ip
if address[0] not in self.cfg.allow: if address[0] not in self.cfg.allow:
@ -93,17 +100,12 @@ class Proxy:
logger.info('STATS TO %s', address) logger.info('STATS TO %s', address)
if full: data = stats.GlobalStats.get_stats(self.ns)
data = self.stat.full_as_csv()
else:
data = self.stat.simple_as_csv()
async for v in data: for v in data:
logger.debug('SENDING %s', v)
await source.sendall(v.encode() + b'\n') await source.sendall(v.encode() + b'\n')
# Method responsible of proxying requests
async def __call__(self, source, address: typing.Tuple[str, int]) -> None:
await self.proxy(source, address)
async def proxy(self, source, address: typing.Tuple[str, int]) -> None: async def proxy(self, source, address: typing.Tuple[str, int]) -> None:
logger.info('OPEN FROM %s', address) logger.info('OPEN FROM %s', address)
@ -119,7 +121,6 @@ class Proxy:
logger.exception('HANDSHAKE') logger.exception('HANDSHAKE')
logger.error('HANDSHAKE from %s', address) logger.error('HANDSHAKE from %s', address)
await source.sendall(b'HANDSHAKE_ERROR') await source.sendall(b'HANDSHAKE_ERROR')
# Closes connection now # Closes connection now
return return
@ -166,7 +167,7 @@ class Proxy:
await source.sendall(b'OK') await source.sendall(b'OK')
# Initialize own stats counter # Initialize own stats counter
counter = await self.stat.new() counter = stats.Stats(self.ns)
# Open remote server connection # Open remote server connection
try: try:
@ -184,8 +185,7 @@ class Proxy:
logger.error('REMOTE from %s: %s', address, e) logger.error('REMOTE from %s: %s', address, e)
finally: finally:
await counter.close() counter.close()
logger.info('CLOSED FROM %s', address) logger.info('CLOSED FROM %s', address)
logger.info('STATS: %s', counter.as_csv())

View File

@ -28,30 +28,30 @@
''' '''
@author: Adolfo Gómez, dkmaster at dkmon dot com @author: Adolfo Gómez, dkmaster at dkmon dot com
''' '''
import multiprocessing
import time import time
import logging
import typing
import io import io
import ssl import ssl
import logging import logging
import typing import typing
import curio import curio
import blist
from . import config from . import config
from . import consts from . import consts
if typing.TYPE_CHECKING:
from multiprocessing.managers import Namespace, SyncManager
INTERVAL = 2 # Interval in seconds between stats update
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Locker for id assigner
assignLock = curio.Lock()
# Tuple index for several stats
SENT, RECV = 0, 1
# Subclasses for += operation to work
class StatsSingleCounter: class StatsSingleCounter:
def __init__(self, parent: 'StatsConnection', for_receiving=True) -> None: def __init__(self, parent: 'Stats', for_receiving=True) -> None:
if for_receiving: if for_receiving:
self.adder = parent.add_recv self.adder = parent.add_recv
else: else:
@ -62,56 +62,34 @@ class StatsSingleCounter:
return self return self
class StatsConnection: class Stats:
id: int ns: 'Namespace'
recv: int
sent: int sent: int
start_time: int recv: int
parent: 'Stats' last: float
# Bandwidth stats (SENT, RECV) def __init__(self, ns: 'Namespace'):
last: typing.List[int] self.ns = ns
last_time: typing.List[float] self.ns.current += 1
self.ns.total += 1
self.sent = 0
self.recv = 0
self.last = time.monotonic()
bandwidth: typing.List[int] def update(self, force: bool = False):
max_bandwidth: typing.List[int] now = time.monotonic()
if force or now - self.last > INTERVAL:
def __init__(self, parent: 'Stats', id: int) -> None: self.last = now
self.id = id self.ns.recv = self.recv
self.recv = self.sent = 0 self.ns.sent = self.sent
now = time.time()
self.start_time = int(now)
self.parent = parent
self.last = [0, 0]
self.last_time = [now, now]
self.bandwidth = [0, 0]
self.max_bandwidth = [0, 0]
def update_bandwidth(self, kind: int, counter: int):
now = time.time()
elapsed = now - self.last_time[kind]
# Update only when enouth data
if elapsed < consts.BANDWIDTH_TIME:
return
total = counter - self.last[kind]
self.bandwidth[kind] = int(float(total) / elapsed)
self.last[kind] = counter
self.last_time[kind] = now
if self.bandwidth[kind] > self.max_bandwidth[kind]:
self.max_bandwidth[kind] = self.bandwidth[kind]
def add_recv(self, size: int) -> None: def add_recv(self, size: int) -> None:
self.recv += size self.recv += size
self.update_bandwidth(RECV, counter=self.recv) self.update()
self.parent.add_recv(size)
def add_sent(self, size: int) -> None: def add_sent(self, size: int) -> None:
self.sent += size self.sent += size
self.update_bandwidth(SENT, counter=self.sent) self.update()
self.parent.add_sent(size)
def as_sent_counter(self) -> 'StatsSingleCounter': def as_sent_counter(self) -> 'StatsSingleCounter':
return StatsSingleCounter(self, False) return StatsSingleCounter(self, False)
@ -119,114 +97,34 @@ class StatsConnection:
def as_recv_counter(self) -> 'StatsSingleCounter': def as_recv_counter(self) -> 'StatsSingleCounter':
return StatsSingleCounter(self, True) return StatsSingleCounter(self, True)
async def close(self) -> None: def close(self):
if self.id: self.update(True)
logger.debug(f'STAT {self.id} closed') self.ns.current -= 1
await self.parent.remove(self.id)
self.id = 0
def as_csv(self, separator: typing.Optional[str] = None) -> str: # Stats collector thread
separator = separator or ';' class GlobalStats:
# With connections of less than a second, consider them as a second manager: 'SyncManager'
elapsed = (int(time.time()) - self.start_time) ns: 'Namespace'
counter: int
return separator.join( def __init__(self):
str(i) super().__init__()
for i in ( self.manager = multiprocessing.Manager()
self.id, self.ns = self.manager.Namespace()
self.start_time,
elapsed,
self.sent,
self.bandwidth[SENT],
self.max_bandwidth[SENT],
self.recv,
self.bandwidth[RECV],
self.max_bandwidth[RECV],
)
)
def __str__(self) -> str: # Counters
return f'{self.id} t:{int(time.time())-self.start_time}, r:{self.recv}, s:{self.sent}>' self.ns.current = 0
self.ns.total = 0
self.ns.sent = 0
self.ns.recv = 0
self.counter = 0
# For sorted array def info(self) -> typing.Iterable[str]:
def __lt__(self, other) -> bool: return GlobalStats.get_stats(self.ns)
if isinstance(other, int):
return self.id < other
if not isinstance(other, StatsConnection):
raise NotImplemented
return self.id < other.id
def __eq__(self, other) -> bool:
if isinstance(other, int):
return self.id == other
if not isinstance(other, StatsConnection):
raise NotImplemented
return self.id == other.id
class Stats:
counter_id: int
total_sent: int
total_received: int
current_connections: blist.sortedlist
def __init__(self) -> None:
# First connection will be 1
self.counter_id = 0
self.total_sent = self.total_received = 0
self.current_connections = blist.sortedlist()
async def new(self) -> StatsConnection:
"""Initializes a connection stats counter and returns it id
Returns:
str: connection id
"""
async with assignLock:
self.counter_id += 1
connection = StatsConnection(self, self.counter_id)
self.current_connections.add(connection)
return connection
def add_sent(self, size: int) -> None:
self.total_sent += size
def add_recv(self, size: int) -> None:
self.total_received += size
async def remove(self, connection_id: int) -> None:
async with assignLock:
try:
self.current_connections.remove(connection_id)
except Exception:
logger.debug(
'Tried to remove %s from connections but was not present',
connection_id,
)
# Does not exists, ignore it
pass
async def simple_as_csv(self, separator: typing.Optional[str] = None) -> typing.AsyncIterable[str]:
separator = separator or ';'
yield separator.join(
str(i)
for i in (
self.counter_id,
self.total_sent,
self.total_received,
len(self.current_connections),
)
)
async def full_as_csv(self, separator: typing.Optional[str] = None) -> typing.AsyncIterable[str]:
for i in self.current_connections:
yield i.as_csv(separator)
@staticmethod
def get_stats(ns: 'Namespace') -> typing.Iterable[str]:
yield ';'.join([str(ns.current), str(ns.total), str(ns.sent), str(ns.recv)])
# Stats processor, invoked from command line # Stats processor, invoked from command line
async def getServerStats(detailed: bool = False) -> None: async def getServerStats(detailed: bool = False) -> None:

View File

@ -1,7 +1,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
# Copyright (c) 2020 Virtual Cable S.L.U. # Copyright (c) 2021 Virtual Cable S.L.U.
# All rights reserved. # All rights reserved.
# #
# Redistribution and use in source and binary forms, with or without modification, # Redistribution and use in source and binary forms, with or without modification,
@ -32,6 +32,7 @@
import sys import sys
import argparse import argparse
import multiprocessing import multiprocessing
import threading
import socket import socket
import logging import logging
import typing import typing
@ -47,6 +48,7 @@ from uds_tunnel import stats
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from multiprocessing.connection import Connection from multiprocessing.connection import Connection
from multiprocessing.managers import Namespace
BACKLOG = 100 BACKLOG = 100
@ -76,7 +78,9 @@ def setup_log(cfg: config.ConfigurationType) -> None:
log.addHandler(fileh) log.addHandler(fileh)
async def tunnel_proc_async(pipe: 'Connection', cfg: config.ConfigurationType) -> None: async def tunnel_proc_async(
pipe: 'Connection', cfg: config.ConfigurationType, ns: 'Namespace'
) -> None:
def get_socket(pipe: 'Connection') -> typing.Tuple[socket.SocketType, typing.Any]: def get_socket(pipe: 'Connection') -> typing.Tuple[socket.SocketType, typing.Any]:
try: try:
while True: while True:
@ -92,7 +96,7 @@ async def tunnel_proc_async(pipe: 'Connection', cfg: config.ConfigurationType) -
pipe: 'Connection', cfg: config.ConfigurationType, group: curio.TaskGroup pipe: 'Connection', cfg: config.ConfigurationType, group: curio.TaskGroup
) -> None: ) -> None:
# Instantiate a proxy redirector for this process (we only need one per process!!) # Instantiate a proxy redirector for this process (we only need one per process!!)
tunneler = proxy.Proxy(cfg) tunneler = proxy.Proxy(cfg, ns)
# Generate SSL context # Generate SSL context
context = curio.ssl.SSLContext(curio.ssl.PROTOCOL_TLS_SERVER) context = curio.ssl.SSLContext(curio.ssl.PROTOCOL_TLS_SERVER)
@ -105,15 +109,18 @@ async def tunnel_proc_async(pipe: 'Connection', cfg: config.ConfigurationType) -
context.load_dh_params(cfg.ssl_dhparam) context.load_dh_params(cfg.ssl_dhparam)
while True: while True:
sock, address = await curio.run_in_thread(get_socket, pipe) try:
if not sock: sock, address = await curio.run_in_thread(get_socket, pipe)
break if not sock:
logger.debug( break
f'{multiprocessing.current_process().pid!r}: Got new connection from {address!r}' logger.debug(
) f'{multiprocessing.current_process().pid!r}: Got new connection from {address!r}'
sock = await context.wrap_socket(curio.io.Socket(sock), server_side=True) )
await group.spawn(tunneler, sock, address) sock = await context.wrap_socket(curio.io.Socket(sock), server_side=True)
del sock await group.spawn(tunneler, sock, address)
del sock
except Exception as e:
logger.error('SETING UP CONNECTION: %s', e)
async with curio.TaskGroup() as tg: async with curio.TaskGroup() as tg:
await tg.spawn(run_server, pipe, cfg, tg) await tg.spawn(run_server, pipe, cfg, tg)
@ -133,10 +140,13 @@ def tunnel_main():
typing.Tuple['Connection', multiprocessing.Process, psutil.Process] typing.Tuple['Connection', multiprocessing.Process, psutil.Process]
] = [] ] = []
stats_collector = stats.GlobalStats()
for i in range(cfg.workers): for i in range(cfg.workers):
own_conn, child_conn = multiprocessing.Pipe() own_conn, child_conn = multiprocessing.Pipe()
task = multiprocessing.Process( task = multiprocessing.Process(
target=curio.run, args=(tunnel_proc_async, child_conn, cfg) target=curio.run,
args=(tunnel_proc_async, child_conn, cfg, stats_collector.ns),
) )
task.start() task.start()
child.append((own_conn, task, psutil.Process(task.pid))) child.append((own_conn, task, psutil.Process(task.pid)))
@ -166,10 +176,13 @@ def tunnel_main():
client, addr = sock.accept() client, addr = sock.accept()
# Select BEST process for sending this new connection # Select BEST process for sending this new connection
best_child().send(message.Message(message.Command.TUNNEL, (client, addr))) best_child().send(message.Message(message.Command.TUNNEL, (client, addr)))
del client # Ensure socket is controlled on child process
except Exception: except Exception:
logger.exception('Mar')
pass pass
logger.info('Exiting tunnel server')
if sock: if sock:
sock.close() sock.close()
@ -196,7 +209,12 @@ def main() -> None:
if args.tunnel: if args.tunnel:
tunnel_main() tunnel_main()
parser.print_help() elif args.detailed_stats:
curio.run(stats.getServerStats, True)
elif args.stats:
curio.run(stats.getServerStats, False)
else:
parser.print_help()
if __name__ == "__main__": if __name__ == "__main__":