forked from shaba/openuds
Fixed 3.5 tunnel DOS attacks tolerance
This commit is contained in:
parent
d2ef6e3704
commit
b3047e366d
@ -50,11 +50,11 @@ class Processes:
|
||||
for i, c in enumerate(self.children):
|
||||
try:
|
||||
if c[2].status() == 'zombie': # Bad kill!!
|
||||
raise psutil.ZombieProcess(c[2].pid)
|
||||
raise psutil.ZombieProcess(c[2])
|
||||
percent = c[2].cpu_percent()
|
||||
except (psutil.ZombieProcess, psutil.NoSuchProcess) as e:
|
||||
# Process is missing...
|
||||
logger.warning('Missing process found: %s', e.pid)
|
||||
logger.warning('Missing process found: %s', e)
|
||||
try:
|
||||
c[0].close() # Close pipe to missing process
|
||||
except Exception:
|
||||
|
@ -134,7 +134,7 @@ class Proxy:
|
||||
try:
|
||||
await self.proxy(source, address)
|
||||
except Exception as e:
|
||||
logger.error('Error procesing connection from %s: %s', address, e)
|
||||
logger.exception('Error procesing connection from %s: %s', address, e)
|
||||
|
||||
async def stats(self, full: bool, source, address: typing.Tuple[str, int]) -> None:
|
||||
# Check valid source ip
|
||||
@ -162,8 +162,9 @@ class Proxy:
|
||||
logger.info('CONNECT FROM %s', prettySource)
|
||||
|
||||
# Handshake correct in this point, start SSL connection
|
||||
command: bytes = b''
|
||||
try:
|
||||
command: bytes = await source.recv(consts.COMMAND_LENGTH)
|
||||
command = await source.recv(consts.COMMAND_LENGTH)
|
||||
if command == consts.COMMAND_TEST:
|
||||
logger.info('COMMAND: TEST')
|
||||
await source.sendall(b'OK')
|
||||
@ -181,7 +182,7 @@ class Proxy:
|
||||
|
||||
if command != consts.COMMAND_OPEN:
|
||||
# Invalid command
|
||||
raise Exception()
|
||||
raise Exception(command)
|
||||
|
||||
# Now, read a TICKET_LENGTH (64) bytes string, that must be [a-zA-Z0-9]{64}
|
||||
ticket: bytes = await source.recv(consts.TICKET_LENGTH)
|
||||
@ -193,27 +194,32 @@ class Proxy:
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error('ERROR %s', e.args[0] if e.args else e)
|
||||
await source.sendall(b'ERROR_TICKET')
|
||||
try:
|
||||
await source.sendall(b'ERROR_TICKET')
|
||||
except Exception:
|
||||
pass # Ignore errors
|
||||
return
|
||||
|
||||
prettyDest = f"{result['host']}:{result['port']}"
|
||||
logger.info('OPEN TUNNEL FROM %s to %s', prettySource, prettyDest)
|
||||
|
||||
except Exception:
|
||||
if consts.DEBUG:
|
||||
logger.exception('COMMAND')
|
||||
logger.error('ERROR from %s', prettySource)
|
||||
await source.sendall(b'ERROR_COMMAND')
|
||||
logger.error('ERROR from %s: COMMAND %s', prettySource, command)
|
||||
try:
|
||||
await source.sendall(b'ERROR_COMMAND')
|
||||
except Exception:
|
||||
pass # Ignore errors
|
||||
return
|
||||
|
||||
# Communicate source OPEN is ok
|
||||
await source.sendall(b'OK')
|
||||
|
||||
# Initialize own stats counter
|
||||
counter = stats.Stats(self.ns)
|
||||
|
||||
# Open remote server connection
|
||||
try:
|
||||
# Communicate source OPEN is ok
|
||||
await source.sendall(b'OK')
|
||||
|
||||
# Open remote server connection
|
||||
destination = await curio.open_connection(
|
||||
result['host'], int(result['port'])
|
||||
)
|
||||
|
@ -40,7 +40,8 @@ import logging
|
||||
import typing
|
||||
|
||||
import curio
|
||||
import psutil
|
||||
import curio.io
|
||||
import curio.errors
|
||||
import setproctitle
|
||||
|
||||
from uds_tunnel import config
|
||||
@ -102,34 +103,16 @@ async def tunnel_proc_async(
|
||||
pipe: 'Connection', cfg: config.ConfigurationType, ns: 'Namespace'
|
||||
) -> None:
|
||||
def get_socket(pipe: 'Connection') -> typing.Tuple[typing.Optional[socket.SocketType], typing.Any]:
|
||||
try:
|
||||
data: bytes = b''
|
||||
while True:
|
||||
msg: message.Message = pipe.recv()
|
||||
if msg.command == message.Command.TUNNEL and msg.connection:
|
||||
# Connection done, check for handshake
|
||||
source, address = msg.connection
|
||||
try:
|
||||
msg: message.Message = pipe.recv()
|
||||
if msg.command == message.Command.TUNNEL and msg.connection:
|
||||
return msg.connection
|
||||
|
||||
try:
|
||||
# First, ensure handshake (simple handshake) and command
|
||||
data = source.recv(len(consts.HANDSHAKE_V1))
|
||||
|
||||
if data != consts.HANDSHAKE_V1:
|
||||
raise Exception() # Invalid handshake
|
||||
except Exception:
|
||||
if consts.DEBUG:
|
||||
logger.exception('HANDSHAKE')
|
||||
logger.error('HANDSHAKE from %s (%s)', address, data.hex())
|
||||
# Close Source and continue
|
||||
source.close()
|
||||
continue
|
||||
|
||||
return msg.connection
|
||||
|
||||
# Process other messages, and retry
|
||||
except Exception:
|
||||
logger.exception('Receiving data from parent process')
|
||||
return None, None
|
||||
# Process other messages, and retry
|
||||
except Exception:
|
||||
logger.exception('Receiving data from parent process')
|
||||
return None, None
|
||||
|
||||
async def run_server(
|
||||
pipe: 'Connection', cfg: config.ConfigurationType, group: curio.TaskGroup
|
||||
@ -147,20 +130,37 @@ async def tunnel_proc_async(
|
||||
if cfg.ssl_dhparam:
|
||||
context.load_dh_params(cfg.ssl_dhparam)
|
||||
|
||||
async def processSocket(ssock: socket.socket) -> None:
|
||||
sock = curio.io.Socket(ssock)
|
||||
try:
|
||||
# First, ensure handshake (simple handshake) and command
|
||||
async with curio.timeout_after(3): # type: ignore
|
||||
data = await sock.recv(len(consts.HANDSHAKE_V1))
|
||||
|
||||
if data != consts.HANDSHAKE_V1:
|
||||
raise Exception(data) # Invalid handshake
|
||||
except (curio.errors.CancelledError, Exception) as e:
|
||||
logger.error('HANDSHAKE from %s (%s)', address, 'timeout' if isinstance(e, curio.errors.CancelledError) else e)
|
||||
# Close Source and continue
|
||||
await sock.close()
|
||||
return
|
||||
sslsock = await context.wrap_socket(
|
||||
sock, server_side=True # type: ignore
|
||||
)
|
||||
await group.spawn(tunneler, sslsock, address)
|
||||
del sslsock
|
||||
|
||||
|
||||
while True:
|
||||
address = ('', '')
|
||||
try:
|
||||
sock, address = await curio.run_in_thread(get_socket, pipe)
|
||||
if not sock:
|
||||
ssock, address = await curio.run_in_thread(get_socket, pipe)
|
||||
if not ssock:
|
||||
break
|
||||
logger.debug(
|
||||
f'CONNECTION from {address!r} (pid: {os.getpid()})'
|
||||
)
|
||||
sock = await context.wrap_socket(
|
||||
curio.io.Socket(sock), server_side=True # type: ignore
|
||||
)
|
||||
await group.spawn(tunneler, sock, address)
|
||||
del sock
|
||||
await group.spawn(processSocket, ssock)
|
||||
except Exception:
|
||||
logger.error('NEGOTIATION ERROR from %s', address[0])
|
||||
|
||||
@ -173,6 +173,8 @@ async def tunnel_proc_async(
|
||||
# task.joined = True
|
||||
# del task
|
||||
|
||||
logger.info('PROCESS %s stopped', os.getpid())
|
||||
|
||||
|
||||
def tunnel_main():
|
||||
cfg = config.read()
|
||||
@ -230,7 +232,7 @@ def tunnel_main():
|
||||
client, addr = sock.accept()
|
||||
client.settimeout(3.0)
|
||||
|
||||
logger.debug('CONNECTION from %s', addr)
|
||||
logger.info('CONNECTION from %s', addr[0])
|
||||
# Select BEST process for sending this new connection
|
||||
prcs.best_child().send(
|
||||
message.Message(message.Command.TUNNEL, (client, addr))
|
||||
|
Loading…
x
Reference in New Issue
Block a user