From 4d26df95804ae58fbfab88e91e4918337a5d3208 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adolfo=20G=C3=B3mez=20Garc=C3=ADa?= <dkmaster@dkmon.com>
Date: Tue, 16 May 2023 01:09:32 +0200
Subject: [PATCH] upgrading and linting tunnel

---
 tunnel-server/src/uds_tunnel/consts.py |  4 ++
 tunnel-server/src/uds_tunnel/tunnel.py | 57 ++++++++------------------
 2 files changed, 22 insertions(+), 39 deletions(-)

diff --git a/tunnel-server/src/uds_tunnel/consts.py b/tunnel-server/src/uds_tunnel/consts.py
index e9210efa5..87fb0c0a6 100644
--- a/tunnel-server/src/uds_tunnel/consts.py
+++ b/tunnel-server/src/uds_tunnel/consts.py
@@ -28,6 +28,7 @@
 '''
 Author: Adolfo Gómez, dkmaster at dkmon dot com
 '''
+import re
 import typing
 
 DEBUG = True
@@ -71,3 +72,6 @@ RESPONSE_OK: typing.Final[bytes] = b'OK'
 
 # Backlog for listen socket
 BACKLOG = 1024
+
+# Regular expression for parsing ticket
+TICKET_REGEX = re.compile(f'^[a-zA-Z0-9]{{{TICKET_LENGTH}}}$')
diff --git a/tunnel-server/src/uds_tunnel/tunnel.py b/tunnel-server/src/uds_tunnel/tunnel.py
index 22f1d275c..209bd0098 100644
--- a/tunnel-server/src/uds_tunnel/tunnel.py
+++ b/tunnel-server/src/uds_tunnel/tunnel.py
@@ -78,9 +78,7 @@ class TunnelProtocol(asyncio.Protocol):
     # If there is a timeout task running
     timeout_task: typing.Optional[asyncio.Task] = None
 
-    def __init__(
-        self, owner: 'proxy.Proxy'
-    ) -> None:
+    def __init__(self, owner: 'proxy.Proxy') -> None:
         # If no other side is given, we are the server part
         super().__init__()
         # transport is undefined until connection_made is called
@@ -91,7 +89,7 @@ class TunnelProtocol(asyncio.Protocol):
         self.destination = ('', 0)
         self.tls_version = ''
         self.tls_cipher = ''
-        
+
         # If other_side is given, we are the client part (that is, the tunnel from us to remote machine)
         # In this case, only do_proxy is used
         self.client = None
@@ -124,9 +122,7 @@ class TunnelProtocol(asyncio.Protocol):
 
         async def open_client() -> None:
             try:
-                result = await TunnelProtocol.get_ticket_from_uds(
-                    self.owner.cfg, ticket, self.source
-                )
+                result = await TunnelProtocol.get_ticket_from_uds(self.owner.cfg, ticket, self.source)
             except Exception as e:
                 logger.error('ERROR %s', e.args[0] if e.args else e)
                 self.transport.write(consts.RESPONSE_ERROR_TICKET)
@@ -146,8 +142,7 @@ class TunnelProtocol(asyncio.Protocol):
             try:
                 family = (
                     socket.AF_INET6
-                    if ':' in self.destination[0]
-                    or (self.owner.cfg.ipv6 and not '.' in self.destination[0])
+                    if ':' in self.destination[0] or (self.owner.cfg.ipv6 and '.' not in self.destination[0])
                     else socket.AF_INET
                 )
                 (_, self.client) = await loop.create_connection(
@@ -161,7 +156,7 @@ class TunnelProtocol(asyncio.Protocol):
                 self.transport.resume_reading()
                 # send OK to client
                 self.transport.write(b'OK')
-                self.stats_manager.increment_connections() # Increment connections counters
+                self.stats_manager.increment_connections()  # Increment connections counters
             except Exception as e:
                 logger.error('Error opening connection: %s', e)
                 self.close_connection()
@@ -171,7 +166,7 @@ class TunnelProtocol(asyncio.Protocol):
         # From now, proxy connection
         self.runner = self.do_proxy
 
-    def process_stats(self, full: bool) -> None:
+    def process_stats(self, full: bool) -> None:  # pylint: disable=unused-argument
         # if pasword is not already received, wait for it
         if len(self.cmd) < consts.PASSWORD_LENGTH + consts.COMMAND_LENGTH:
             return
@@ -246,22 +241,22 @@ class TunnelProtocol(asyncio.Protocol):
             try:
                 if command == consts.COMMAND_OPEN:
                     self.process_open()
-                elif command == consts.COMMAND_TEST:
+                    return
+                if command == consts.COMMAND_TEST:
                     self.clean_timeout()  # Stop timeout
                     logger.info('COMMAND: TEST')
                     self.transport.write(consts.RESPONSE_OK)
                     self.close_connection()
                     return
-                elif command in (consts.COMMAND_STAT, consts.COMMAND_INFO):
+                if command in (consts.COMMAND_STAT, consts.COMMAND_INFO):
                     # This is an stats requests
                     try:
-                        self.process_stats(full=command == consts.COMMAND_STAT) 
+                        self.process_stats(full=command == consts.COMMAND_STAT)
                     except Exception as e:
                         logger.error('ERROR processing stats: %s', e.args[0] if e.args else e)
                     self.close_connection()
                     return
-                else:
-                    raise Exception('Invalid command')
+                raise Exception('Invalid command')
             except Exception:
                 logger.error('ERROR from %s', self.pretty_source())
                 self.transport.write(consts.RESPONSE_ERROR_COMMAND)
@@ -298,9 +293,7 @@ class TunnelProtocol(asyncio.Protocol):
             )
             # Notify end to uds, using a task becase we are not an async function
             asyncio.get_event_loop().create_task(
-                TunnelProtocol.notify_end_to_uds(
-                    self.owner.cfg, self.notify_ticket, self.stats_manager
-                )
+                TunnelProtocol.notify_end_to_uds(self.owner.cfg, self.notify_ticket, self.stats_manager)
             )
             self.notify_ticket = b''  # Clean up so no more notifications
         else:
@@ -350,7 +343,6 @@ class TunnelProtocol(asyncio.Protocol):
     def pretty_destination(self) -> str:
         return TunnelProtocol.pretty_address(self.destination)
 
-
     @staticmethod
     async def _read_from_uds(
         cfg: config.ConfigurationType,
@@ -359,13 +351,9 @@ class TunnelProtocol(asyncio.Protocol):
         queryParams: typing.Optional[typing.Mapping[str, str]] = None,
     ) -> typing.MutableMapping[str, typing.Any]:
         try:
-            url = (
-                cfg.uds_server + '/' + ticket.decode() + '/' + msg + '/' + cfg.uds_token
-            )
+            url = cfg.uds_server + '/' + ticket.decode() + '/' + msg + '/' + cfg.uds_token
             if queryParams:
-                url += '?' + '&'.join(
-                    [f'{key}={value}' for key, value in queryParams.items()]
-                )
+                url += '?' + '&'.join([f'{key}={value}' for key, value in queryParams.items()])
             # Set options
             options: typing.Dict[str, typing.Any] = {'timeout': cfg.uds_timeout}
             if cfg.uds_verify_ssl is False:
@@ -378,24 +366,15 @@ class TunnelProtocol(asyncio.Protocol):
                         raise Exception(await r.text())
                     return await r.json()
         except Exception as e:
-            raise Exception(f'TICKET COMMS ERROR: {ticket.decode()} {msg} {e!s}')
+            raise Exception(f'TICKET COMMS ERROR: {ticket.decode()} {msg} {e!s}') from e
 
     @staticmethod
     async def get_ticket_from_uds(
         cfg: config.ConfigurationType, ticket: bytes, address: typing.Tuple[str, int]
     ) -> typing.MutableMapping[str, typing.Any]:
-        # Sanity checks
-        if len(ticket) != consts.TICKET_LENGTH:
-            raise ValueError(f'TICKET INVALID (len={len(ticket)})')
-
-        for n, i in enumerate(ticket.decode(errors='ignore')):
-            if (
-                (i >= 'a' and i <= 'z')
-                or (i >= '0' and i <= '9')
-                or (i >= 'A' and i <= 'Z')
-            ):
-                continue  # Correctus
-            raise ValueError(f'TICKET INVALID (char {i} at pos {n})')
+        # Check ticket using re
+        if consts.TICKET_REGEX.match(ticket.decode(errors='replace')) is None:
+            raise ValueError(f'TICKET INVALID ({ticket.decode(errors="replace")})')
 
         return await TunnelProtocol._read_from_uds(cfg, ticket, address[0])