1
0
mirror of https://github.com/dkmstr/openuds.git synced 2024-12-22 13:34:04 +03:00
openuds/tunnel-server/samples/async_upgrade_server.py
2022-12-22 15:25:45 +01:00

137 lines
4.5 KiB
Python
Executable File

#!/usr/bin/env python3
import asyncio
import ssl
import logging
import socket
import typing
if typing.TYPE_CHECKING:
import asyncio.transports
logger = logging.getLogger(__name__)
BACKLOG = 100
STATE_UNINITIALIZED = 0
STATE_COMMAND = 1
STATE_PROXY = 2
COMMAND_LENGTH = 4
TICKET_LENGTH = 48
# Protocol
class TunnelProtocol(asyncio.Protocol):
# future to mark eof
finished: asyncio.Future
transport: 'asyncio.transports.Transport'
other_side: 'TunnelProtocol'
state: int
cmd: bytes
def __init__(self, other_side: typing.Optional['TunnelProtocol'] = None) -> None:
# If no other side is given, we are the server part
super().__init__()
self.state = STATE_UNINITIALIZED
self.other_side = (
other_side if other_side else self
) # No other side, self is just a placeholder
# transport is undefined until conne
self.finished = asyncio.Future()
self.cmd = b''
def connection_made(self, transport: 'asyncio.transports.BaseTransport') -> None:
logger.debug('Connection made: %s', transport.get_extra_info('peername'))
# Update state based on if we are the client or server
self.state = STATE_COMMAND if self.other_side is self else STATE_PROXY
# We know for sure that the transport is a Transport.
self.transport = typing.cast('asyncio.transports.Transport', transport)
self.cmd = b''
def process_command(self, data: bytes) -> None:
self.cmd += data
if len(self.cmd) >= COMMAND_LENGTH:
logger.debug('Command received: %s', self.cmd)
if self.cmd[:4] == b'OPEN':
# Open Command has the ticket behind it
if len(self.cmd) < TICKET_LENGTH + COMMAND_LENGTH:
return # Wait for more data
# log the ticket
logger.debug('Ticket received: %s', self.cmd[4:4+TICKET_LENGTH])
loop = asyncio.get_event_loop()
async def open_other_side() -> None:
try:
(transport, protocol) = await loop.create_connection(
lambda: TunnelProtocol(self), 'www.google.com', 80
)
self.other_side = typing.cast('TunnelProtocol', protocol)
self.other_side.transport.write(
b'GET / HTTP/1.0\r\nHost: www.google.com\r\n\r\n'
)
except Exception as e:
logger.error('Error opening connection: %s', e)
# Send error to client
self.transport.write(b'ERR')
self.transport.close()
loop.create_task(open_other_side())
self.state = STATE_PROXY
def process_proxy(self, data: bytes) -> None:
logger.debug('Processing proxy: %s', len(data))
self.other_side.transport.write(data)
def data_received(self, data: bytes):
logger.debug('Data received: %s', len(data))
if self.state == STATE_COMMAND:
self.process_command(data)
elif self.state == STATE_PROXY:
self.process_proxy(data)
else:
logger.error('Invalid state reached!')
def connection_lost(self, exc: typing.Optional[Exception]) -> None:
self.finished.set_result(True)
if self.other_side is not self:
self.other_side.transport.close()
async def main():
# Init logger
logging.basicConfig(
level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
)
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
context.load_cert_chain('tests/testing.pem', 'tests/testing.key')
# Create a server
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
sock.settimeout(444.0) # So we can check for stop from time to time
sock.bind(('0.0.0.0', 7777))
sock.listen(BACKLOG)
loop = asyncio.get_running_loop()
# Accepts connections
client, addr = sock.accept()
logger.debug('Accepted connection')
data = client.recv(4)
print(data)
# Upgrade connection to SSL, and use asyncio to handle the rest
(_, protocol) = await loop.connect_accepted_socket( # type: ignore
lambda: TunnelProtocol(), client, ssl=context
)
await protocol.finished
if __name__ == '__main__':
asyncio.run(main())