Adding osDetector to UDSClient

This commit is contained in:
Adolfo Gómez García 2021-05-04 13:05:53 +02:00
parent 98293bba75
commit 3f6d12c89f
5 changed files with 235 additions and 293 deletions

View File

@ -27,17 +27,15 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
'''
@author: Adolfo Gómez, dkmaster at dkmon dot com
"""
'''
import sys
import webbrowser
import json
import base64
import bz2
from PyQt5 import QtCore, QtWidgets
import base64, bz2
from PyQt5 import QtCore, QtGui, QtWidgets # @UnresolvedImport
import six
from uds.rest import RestRequest
@ -50,13 +48,11 @@ from uds import VERSION
from UDSWindow import Ui_MainWindow
# Server before this version uses "unsigned" scripts
OLD_METHOD_VERSION = "2.4.0"
OLD_METHOD_VERSION = '2.4.0'
class RetryException(Exception):
pass
class UDSClient(QtWidgets.QMainWindow):
ticket = None
@ -65,14 +61,12 @@ class UDSClient(QtWidgets.QMainWindow):
animTimer = None
anim = 0
animInverted = False
serverVersion = "X.Y.Z" # Will be overwriten on getVersion
serverVersion = 'X.Y.Z' # Will be overwriten on getVersion
req = None
def __init__(self):
QtWidgets.QMainWindow.__init__(self)
self.setWindowFlags(
QtCore.Qt.FramelessWindowHint | QtCore.Qt.WindowStaysOnTopHint
)
self.setWindowFlags(QtCore.Qt.FramelessWindowHint | QtCore.Qt.WindowStaysOnTopHint)
self.ui = Ui_MainWindow()
self.ui.setupUi(self)
@ -80,7 +74,7 @@ class UDSClient(QtWidgets.QMainWindow):
self.ui.progressBar.setValue(0)
self.ui.cancelButton.clicked.connect(self.cancelPushed)
self.ui.info.setText("Initializing...")
self.ui.info.setText('Initializing...')
screen = QtWidgets.QDesktopWidget().screenGeometry()
mysize = self.geometry()
@ -96,30 +90,27 @@ class UDSClient(QtWidgets.QMainWindow):
self.startAnim()
def closeWindow(self):
self.close()
def processError(self, data):
if "error" in data:
if 'error' in data:
# QtWidgets.QMessageBox.critical(self, 'Request error {}'.format(data.get('retryable', '0')), data['error'], QtWidgets.QMessageBox.Ok)
if data.get("retryable", "0") == "1":
raise RetryException(data["error"])
if data.get('retryable', '0') == '1':
raise RetryException(data['error'])
raise Exception(data["error"])
raise Exception(data['error'])
# QtWidgets.QMessageBox.critical(self, 'Request error', rest.data['error'], QtWidgets.QMessageBox.Ok)
# self.closeWindow()
# return
def showError(self, error):
logger.error("got error: %s", error)
logger.error('got error: %s', error)
self.stopAnim()
self.ui.info.setText(
"UDS Plugin Error"
) # In fact, main window is hidden, so this is not visible... :)
self.ui.info.setText('UDS Plugin Error') # In fact, main window is hidden, so this is not visible... :)
self.closeWindow()
QtWidgets.QMessageBox.critical(
None, "UDS Plugin Error", "{}".format(error), QtWidgets.QMessageBox.Ok
)
QtWidgets.QMessageBox.critical(None, 'UDS Plugin Error', '{}'.format(error), QtWidgets.QMessageBox.Ok)
self.withError = True
def cancelPushed(self):
@ -146,26 +137,21 @@ class UDSClient(QtWidgets.QMainWindow):
self.animTimer.stop()
def getVersion(self):
self.req = RestRequest("", self, self.version)
self.req = RestRequest('', self, self.version)
self.req.get()
def version(self, data):
try:
self.processError(data)
self.ui.info.setText("Processing...")
self.ui.info.setText('Processing...')
if data["result"]["requiredVersion"] > VERSION:
QtWidgets.QMessageBox.critical(
self,
"Upgrade required",
"A newer connector version is required.\nA browser will be opened to download it.",
QtWidgets.QMessageBox.Ok,
)
webbrowser.open(data["result"]["downloadUrl"])
if data['result']['requiredVersion'] > VERSION:
QtWidgets.QMessageBox.critical(self, 'Upgrade required', 'A newer connector version is required.\nA browser will be opened to download it.', QtWidgets.QMessageBox.Ok)
webbrowser.open(data['result']['downloadUrl'])
self.closeWindow()
return
self.serverVersion = data["result"]["requiredVersion"]
self.serverVersion = data['result']['requiredVersion']
self.getTransportData()
except RetryException as e:
@ -177,66 +163,55 @@ class UDSClient(QtWidgets.QMainWindow):
def getTransportData(self):
try:
self.req = RestRequest(
"/{}/{}".format(self.ticket, self.scrambler),
self,
self.transportDataReceived,
params={"hostname": tools.getHostName(), "version": VERSION},
)
self.req = RestRequest('/{}/{}'.format(self.ticket, self.scrambler), self, self.transportDataReceived, params={'hostname': tools.getHostName(), 'version': VERSION})
self.req.get()
except Exception as e:
logger.exception("Got exception on getTransportData")
logger.exception('Got exception on getTransportData')
raise e
def transportDataReceived(self, data):
logger.debug("Transport data received")
logger.debug('Transport data received')
try:
self.processError(data)
params = None
if self.serverVersion <= OLD_METHOD_VERSION:
script = bz2.decompress(base64.b64decode(data["result"]))
# This fixes uds 2.2 "write" string on binary streams on some transport
script = bz2.decompress(base64.b64decode(data['result']))
# This fixes uds 2.2 "write" string on binary streams on some transport
script = script.replace(b'stdin.write("', b'stdin.write(b"')
script = script.replace(b"version)", b'version.decode("utf-8"))')
script = script.replace(b'version)', b'version.decode("utf-8"))')
else:
res = data["result"]
res = data['result']
# We have three elements on result:
# * Script
# * Signature
# * Script data
# We test that the Script has correct signature, and them execute it with the parameters
# script, signature, params = res['script'].decode('base64').decode('bz2'), res['signature'], json.loads(res['params'].decode('base64').decode('bz2'))
script, signature, params = (
bz2.decompress(base64.b64decode(res["script"])),
res["signature"],
json.loads(bz2.decompress(base64.b64decode(res["params"]))),
)
#script, signature, params = res['script'].decode('base64').decode('bz2'), res['signature'], json.loads(res['params'].decode('base64').decode('bz2'))
script, signature, params = bz2.decompress(base64.b64decode(res['script'])), res['signature'], json.loads(bz2.decompress(base64.b64decode(res['params'])))
if tools.verifySignature(script, signature) is False:
logger.error("Signature is invalid")
logger.error('Signature is invalid')
raise Exception(
"Invalid UDS code signature. Please, report to administrator"
)
raise Exception('Invalid UDS code signature. Please, report to administrator')
self.stopAnim()
if "darwin" in sys.platform:
if 'darwin' in sys.platform:
self.showMinimized()
QtCore.QTimer.singleShot(3000, self.endScript)
self.hide()
six.exec_(script.decode("utf-8"), globals(), {"parent": self, "sp": params})
six.exec_(script.decode("utf-8"), globals(), {'parent': self, 'sp': params})
except RetryException as e:
self.ui.info.setText(six.text_type(e) + ", retrying access...")
self.ui.info.setText(six.text_type(e) + ', retrying access...')
# Retry operation in ten seconds
QtCore.QTimer.singleShot(10000, self.getTransportData)
except Exception as e:
# logger.exception('Got exception executing script:')
#logger.exception('Got exception executing script:')
self.showError(e)
def endScript(self):
@ -259,111 +234,84 @@ class UDSClient(QtWidgets.QMainWindow):
self.closeWindow()
def start(self):
"""
'''
Starts proccess by requesting version info
"""
self.ui.info.setText("Initializing...")
'''
self.ui.info.setText('Initializing...')
QtCore.QTimer.singleShot(100, self.getVersion)
def done(data):
QtWidgets.QMessageBox.critical(
None, "Notice", six.text_type(data.data), QtWidgets.QMessageBox.Ok
)
QtWidgets.QMessageBox.critical(None, 'Notice', six.text_type(data.data), QtWidgets.QMessageBox.Ok)
sys.exit(0)
# Ask user to approve endpoint
def approveHost(hostName, parentWindow=None):
settings = QtCore.QSettings()
settings.beginGroup("endpoints")
settings.beginGroup('endpoints')
# approved = settings.value(hostName, False).toBool()
#approved = settings.value(hostName, False).toBool()
approved = bool(settings.value(hostName, False))
errorString = "<p>The server <b>{}</b> must be approved:</p>".format(hostName)
errorString += (
"<p>Only approve UDS servers that you trust to avoid security issues.</p>"
)
errorString = '<p>The server <b>{}</b> must be approved:</p>'.format(hostName)
errorString += '<p>Only approve UDS servers that you trust to avoid security issues.</p>'
if (
approved
or QtWidgets.QMessageBox.warning(
parentWindow,
"ACCESS Warning",
errorString,
QtWidgets.QMessageBox.Yes | QtWidgets.QMessageBox.No,
)
== QtWidgets.QMessageBox.Yes
):
if approved or QtWidgets.QMessageBox.warning(parentWindow, 'ACCESS Warning', errorString, QtWidgets.QMessageBox.Yes | QtWidgets.QMessageBox.No) == QtWidgets.QMessageBox.Yes:
settings.setValue(hostName, True)
approved = True
settings.endGroup()
return approved
if __name__ == "__main__":
logger.debug("Initializing connector")
logger.debug('Initializing connector')
# Initialize app
app = QtWidgets.QApplication(sys.argv)
# Set several info for settings
QtCore.QCoreApplication.setOrganizationName("Virtual Cable S.L.U.")
QtCore.QCoreApplication.setApplicationName("UDS Connector")
QtCore.QCoreApplication.setOrganizationName('Virtual Cable S.L.U.')
QtCore.QCoreApplication.setApplicationName('UDS Connector')
if "darwin" not in sys.platform:
logger.debug("Mac OS *NOT* Detected")
app.setStyle("plastique")
if 'darwin' not in sys.platform:
logger.debug('Mac OS *NOT* Detected')
app.setStyle('plastique')
if six.PY3 is False:
logger.debug("Fixing threaded execution of commands")
logger.debug('Fixing threaded execution of commands')
import threading
threading._DummyThread._Thread__stop = lambda x: 42 # type: ignore # pylint: disable=protected-access
# First parameter must be url
try:
uri = sys.argv[1]
if uri == "--test":
if uri == '--test':
sys.exit(0)
logger.debug("URI: %s", uri)
if uri[:6] != "uds://" and uri[:7] != "udss://":
logger.debug('URI: %s', uri)
if uri[:6] != 'uds://' and uri[:7] != 'udss://':
raise Exception()
ssl = uri[3] == "s"
host, UDSClient.ticket, UDSClient.scrambler = uri.split("//")[1].split("/") # type: ignore
logger.debug(
"ssl:%s, host:%s, ticket:%s, scrambler:%s",
ssl,
host,
UDSClient.ticket,
UDSClient.scrambler,
)
ssl = uri[3] == 's'
host, UDSClient.ticket, UDSClient.scrambler = uri.split('//')[1].split('/') # type: ignore
logger.debug('ssl:%s, host:%s, ticket:%s, scrambler:%s', ssl, host, UDSClient.ticket, UDSClient.scrambler)
except Exception:
logger.debug("Detected execution without valid URI, exiting")
QtWidgets.QMessageBox.critical(
None,
"Notice",
"UDS Client Version {}".format(VERSION),
QtWidgets.QMessageBox.Ok,
)
logger.debug('Detected execution without valid URI, exiting')
QtWidgets.QMessageBox.critical(None, 'Notice', 'UDS Client Version {}'.format(VERSION), QtWidgets.QMessageBox.Ok)
sys.exit(1)
# Setup REST api endpoint
RestRequest.restApiUrl = "{}://{}/rest/client".format(["http", "https"][ssl], host)
logger.debug("Setting request URL to %s", RestRequest.restApiUrl)
RestRequest.restApiUrl = '{}://{}/rest/client'.format(['http', 'https'][ssl], host)
logger.debug('Setting request URL to %s', RestRequest.restApiUrl)
# RestRequest.restApiUrl = 'https://172.27.0.1/rest/client'
try:
logger.debug("Starting execution")
logger.debug('Starting execution')
# Approbe before going on
if approveHost(host) is False:
raise Exception("Host {} was not approved".format(host))
raise Exception('Host {} was not approved'.format(host))
win = UDSClient()
win.show()
@ -371,14 +319,12 @@ if __name__ == "__main__":
win.start()
exitVal = app.exec_()
logger.debug("Execution finished correctly")
logger.debug('Execution finished correctly')
except Exception as e:
logger.exception("Got an exception executing client:")
logger.exception('Got an exception executing client:')
exitVal = 128
QtWidgets.QMessageBox.critical(
None, "Error", six.text_type(e), QtWidgets.QMessageBox.Ok
)
QtWidgets.QMessageBox.critical(None, 'Error', six.text_type(e), QtWidgets.QMessageBox.Ok)
logger.debug("Exiting")
logger.debug('Exiting')
sys.exit(exitVal)

View File

@ -0,0 +1,48 @@
# -*- coding: utf-8 -*-
#
# Copyright (c) 2017-2021 Virtual Cable S.L.U.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of Virtual Cable S.L. nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
'''
@author: Adolfo Gómez, dkmaster at dkmon dot com
'''
from __future__ import unicode_literals
import sys
LINUX = 'Linux'
WINDOWS = 'Windows'
MAC_OS_X = 'Mac os x'
def getOs():
if sys.platform.startswith('linux'):
return LINUX
if sys.platform.startswith('win'):
return WINDOWS
if sys.platform.startswith('darwin'):
return MAC_OS_X
return 'other'

View File

@ -11,7 +11,7 @@
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of Virtual Cable S.L.U. nor the names of its contributors
# * Neither the name of Virtual Cable S.L. nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
@ -32,27 +32,22 @@
# pylint: disable=c-extension-no-member,no-name-in-module
import json
import os
import urllib
import urllib.parse
import certifi
from PyQt5.QtCore import pyqtSignal
from PyQt5.QtCore import QObject, QUrl, QSettings
from PyQt5.QtCore import Qt
from PyQt5.QtNetwork import (
QNetworkAccessManager,
QNetworkRequest,
QNetworkReply,
QSslCertificate,
)
from PyQt5.QtNetwork import QNetworkAccessManager, QNetworkRequest, QNetworkReply, QSslCertificate
from PyQt5.QtWidgets import QMessageBox
from . import os_detector
from . import osDetector
from . import VERSION
class RestRequest(QObject):
restApiUrl = '' #
@ -63,12 +58,16 @@ class RestRequest(QObject):
super(RestRequest, self).__init__()
# private
self._manager = QNetworkAccessManager()
try:
if os.path.exists('/etc/ssl/certs/ca-certificates.crt'):
pass
# os.environ['REQUESTS_CA_BUNDLE'] = '/etc/ssl/certs/ca-certificates.crt'
except Exception:
pass
if params is not None:
url += '?' + '&'.join(
'{}={}'.format(k, urllib.parse.quote(str(v).encode('utf8')))
for k, v in params.items()
)
url += '?' + '&'.join('{}={}'.format(k, urllib.parse.quote(str(v).encode('utf8'))) for k, v in params.items())
self.url = QUrl(RestRequest.restApiUrl + url)
@ -77,21 +76,24 @@ class RestRequest(QObject):
self._manager.sslErrors.connect(self._sslError)
self._parentWindow = parentWindow
self.done.connect(done, Qt.QueuedConnection)
self.done.connect(done, Qt.QueuedConnection) # type: ignore
def _finished(self, reply):
"""
'''
Handle signal 'finished'. A network request has finished.
"""
'''
try:
if reply.error() != QNetworkReply.NoError:
raise Exception(reply.errorString())
data = bytes(reply.readAll())
data = json.loads(data)
except Exception as e:
data = {'result': None, 'error': str(e)}
data = {
'result': None,
'error': str(e)
}
self.done.emit(data)
self.done.emit(data) # type: ignore
reply.deleteLater() # schedule for delete from main event loop
@ -103,27 +105,14 @@ class RestRequest(QObject):
approved = settings.value(digest, False)
errorString = (
'<p>The certificate for <b>{}</b> has the following errors:</p><ul>'.format(
cert.subjectInfo(QSslCertificate.CommonName)
)
)
errorString = '<p>The certificate for <b>{}</b> has the following errors:</p><ul>'.format(cert.subjectInfo(QSslCertificate.CommonName))
for err in errors:
errorString += '<li>' + err.errorString() + '</li>'
errorString += '</ul>'
if (
approved
or QMessageBox.warning(
self._parentWindow,
'SSL Warning',
errorString,
QMessageBox.Yes | QMessageBox.No,
)
== QMessageBox.Yes
):
if approved or QMessageBox.warning(self._parentWindow, 'SSL Warning', errorString, QMessageBox.Yes | QMessageBox.No) == QMessageBox.Yes: # type: ignore
settings.setValue(digest, True)
reply.ignoreSslErrors()
@ -131,14 +120,5 @@ class RestRequest(QObject):
def get(self):
request = QNetworkRequest(self.url)
# Ensure loads certifi certificates
sslCfg = request.sslConfiguration()
sslCfg.addCaCertificates(certifi.where())
request.setSslConfiguration(sslCfg)
request.setRawHeader(
b'User-Agent',
os_detector.getOs().encode('utf-8')
+ b" - UDS Connector "
+ VERSION.encode('utf-8'),
)
request.setRawHeader(b'User-Agent', osDetector.getOs().encode('utf-8') + b" - UDS Connector " + VERSION.encode('utf-8'))
self._manager.get(request)

View File

@ -11,7 +11,7 @@
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of Virtual Cable S.L.U. nor the names of its contributors
# * Neither the name of Virtual Cable S.L. nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
@ -31,20 +31,18 @@
'''
from __future__ import unicode_literals
import base64
import os
from base64 import b64decode
import tempfile
import string
import random
import os
import socket
import stat
import string
import sys
import tempfile
import time
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.serialization import load_pem_public_key
import six
from .log import logger
@ -52,8 +50,10 @@ _unlinkFiles = []
_tasksToWait = []
_execBeforeExit = []
sys_fs_enc = sys.getfilesystemencoding() or 'mbcs'
# Public key for scripts
PUBLIC_KEY = b'''-----BEGIN PUBLIC KEY-----
PUBLIC_KEY = '''-----BEGIN PUBLIC KEY-----
MIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAuNURlGjBpqbglkTTg2lh
dU5qPbg9Q+RofoDDucGfrbY0pjB9ULgWXUetUWDZhFG241tNeKw+aYFTEorK5P+g
ud7h9KfyJ6huhzln9eyDu3k+kjKUIB1PLtA3lZLZnBx7nmrHRody1u5lRaLVplsb
@ -71,19 +71,13 @@ nVgtClKcDDlSaBsO875WDR0CAwEAAQ==
def saveTempFile(content, filename=None):
if filename is None:
filename = ''.join(
random.choice(string.ascii_lowercase + string.digits) for _ in range(16)
)
filename = ''.join(random.choice(string.ascii_lowercase + string.digits) for _ in range(16))
filename = filename + '.uds'
filename = os.path.join(tempfile.gettempdir(), filename)
try:
with open(filename, 'w') as f:
f.write(content)
except Exception as e:
logger.error('Error saving temporary file %s: %s', filename, e)
raise
with open(filename, 'w') as f:
f.write(content)
logger.info('Returning filename')
return filename
@ -94,8 +88,7 @@ def readTempFile(filename):
try:
with open(filename, 'r') as f:
return f.read()
except Exception as e:
logger.warning('Could not read file %s: %s', filename, e)
except Exception:
return None
@ -117,34 +110,32 @@ def findApp(appName, extraPath=None):
fileName = os.path.join(path, appName)
if os.path.isfile(fileName) and (os.stat(fileName).st_mode & stat.S_IXUSR) != 0:
return fileName
logger.warning('Application %s not found on path %s', appName, searchPath)
return None
def getHostName():
"""
'''
Returns current host name
In fact, it's a wrapper for socket.gethostname()
"""
'''
hostname = socket.gethostname()
logger.info('Hostname: %s', hostname)
return hostname
# Queing operations (to be executed before exit)
def addFileToUnlink(filename):
"""
'''
Adds a file to the wait-and-unlink list
"""
'''
_unlinkFiles.append(filename)
def unlinkFiles():
"""
'''
Removes all wait-and-unlink files
"""
'''
if _unlinkFiles:
time.sleep(5) # Wait 5 seconds before deleting anything
@ -180,14 +171,21 @@ def execBeforeExit():
def verifySignature(script, signature):
public_key = load_pem_public_key(backend=default_backend(), data=PUBLIC_KEY)
'''
Verifies with a public key from whom the data came that it was indeed
signed by their private key
param: public_key_loc Path to public key
param: signature String signature to be verified
return: Boolean. True if the signature is valid; False otherwise.
'''
# For signature checking
from Crypto.PublicKey import RSA
from Crypto.Signature import PKCS1_v1_5
from Crypto.Hash import SHA256
# Message option
try:
public_key.verify(
base64.b64decode(signature), script, padding.PKCS1v15(), hashes.SHA256()
)
except Exception: # InvalidSignature
logger.error('Invalid signature for UDS plugin code. Contact Administrator.')
return False
return True
rsakey = RSA.importKey(PUBLIC_KEY)
signer = PKCS1_v1_5.new(rsakey)
digest = SHA256.new(script) # Script is "binary string" here
if signer.verify(digest, b64decode(signature)):
return True
return False

View File

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
#
# Copyright (c) 2021 Virtual Cable S.L.U.
# Copyright (c) 2020 Virtual Cable S.L.U.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
@ -11,7 +11,7 @@
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of Virtual Cable S.L.U. nor the names of its contributors
# * Neither the name of Virtual Cable S.L. nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
@ -33,7 +33,6 @@ import socketserver
import ssl
import threading
import time
import random
import threading
import select
import typing
@ -49,7 +48,6 @@ TUNNEL_LISTENING, TUNNEL_OPENING, TUNNEL_PROCESSING, TUNNEL_ERROR = 0, 1, 2, 3
logger = logging.getLogger(__name__)
class ForwardServer(socketserver.ThreadingTCPServer):
daemon_threads = True
allow_reuse_address = True
@ -57,7 +55,6 @@ class ForwardServer(socketserver.ThreadingTCPServer):
remote: typing.Tuple[str, int]
ticket: str
stop_flag: threading.Event
can_stop: bool
timeout: int
timer: typing.Optional[threading.Timer]
check_certificate: bool
@ -72,30 +69,23 @@ class ForwardServer(socketserver.ThreadingTCPServer):
local_port: int = 0,
check_certificate: bool = True,
) -> None:
local_port = local_port or random.randrange(33000, 53000)
super().__init__(
server_address=(LISTEN_ADDRESS, local_port), RequestHandlerClass=Handler
)
self.remote = remote
self.ticket = ticket
# Negative values for timeout, means "accept always connections"
# "but if no connection is stablished on timeout (positive)"
# "stop the listener"
self.timeout = int(time.time()) + timeout if timeout > 0 else 0
self.timeout = int(time.time()) + timeout if timeout else 0
self.check_certificate = check_certificate
self.stop_flag = threading.Event() # False initial
self.current_connections = 0
self.status = TUNNEL_LISTENING
self.can_stop = False
timeout = abs(timeout) or 60
self.timer = threading.Timer(
abs(timeout), ForwardServer.__checkStarted, args=(self,)
)
self.timer.start()
if timeout:
self.timer = threading.Timer(timeout, ForwardServer.__checkStarted, args=(self,))
self.timer.start()
else:
self.timer = None
def stop(self) -> None:
if not self.stop_flag.is_set():
@ -106,52 +96,13 @@ class ForwardServer(socketserver.ThreadingTCPServer):
self.timer = None
self.shutdown()
def connect(self) -> ssl.SSLSocket:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as rsocket:
logger.info('CONNECT to %s', self.remote)
rsocket.connect(self.remote)
context = ssl.create_default_context()
# Do not "recompress" data, use only "base protocol" compression
context.options |= ssl.OP_NO_COMPRESSION
# If ignore remote certificate
if self.check_certificate is False:
context.check_hostname = False
context.verify_mode = ssl.CERT_NONE
logger.warning('Certificate checking is disabled!')
return context.wrap_socket(rsocket, server_hostname=self.remote[0])
def check(self) -> bool:
if self.status == TUNNEL_ERROR:
return False
try:
with self.connect() as ssl_socket:
ssl_socket.sendall(HANDSHAKE_V1 + b'TEST')
resp = ssl_socket.recv(2)
if resp != b'OK':
raise Exception({'Invalid tunnelresponse: {resp}'})
return True
except Exception as e:
logger.error(
'Error connecting to tunnel server %s: %s', self.server_address, e
)
return False
@property
def stoppable(self) -> bool:
logger.debug('Is stoppable: %s', self.can_stop)
return self.can_stop or (self.timeout != 0 and int(time.time()) > self.timeout)
return self.timeout != 0 and int(time.time()) > self.timeout
@staticmethod
def __checkStarted(fs: 'ForwardServer') -> None:
logger.debug('New connection limit reached')
fs.timer = None
fs.can_stop = True
if fs.current_connections <= 0:
fs.stop()
@ -162,33 +113,46 @@ class Handler(socketserver.BaseRequestHandler):
# server: ForwardServer
def handle(self) -> None:
self.server.current_connections += 1
self.server.status = TUNNEL_OPENING
# If server processing is over time
# If server processing is over time
if self.server.stoppable:
self.server.status = TUNNEL_ERROR
logger.info('Rejected timedout connection')
logger.info('Rejected timedout connection try')
self.request.close() # End connection without processing it
return
self.server.current_connections += 1
# Open remote connection
try:
logger.debug('Ticket %s', self.server.ticket)
with self.server.connect() as ssl_socket:
# Send handhshake + command + ticket
ssl_socket.sendall(HANDSHAKE_V1 + b'OPEN' + self.server.ticket.encode())
# Check response is OK
data = ssl_socket.recv(2)
if data != b'OK':
data += ssl_socket.recv(128)
raise Exception(
f'Error received: {data.decode(errors="ignore")}'
) # Notify error
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as rsocket:
logger.info('CONNECT to %s', self.server.remote)
logger.debug('Ticket %s', self.server.ticket)
# All is fine, now we can tunnel data
self.process(remote=ssl_socket)
rsocket.connect(self.server.remote)
context = ssl.create_default_context()
# If ignore remote certificate
if self.server.check_certificate is False:
context.check_hostname = False
context.verify_mode = ssl.CERT_NONE
logger.warning('Certificate checking is disabled!')
with context.wrap_socket(
rsocket, server_hostname=self.server.remote[0]
) as ssl_socket:
# Send handhshake + command + ticket
ssl_socket.sendall(
HANDSHAKE_V1 + b'OPEN' + self.server.ticket.encode()
)
# Check response is OK
data = ssl_socket.recv(2)
if data != b'OK':
data += ssl_socket.recv(128)
raise Exception(f'Error received: {data.decode()}') # Notify error
# All is fine, now we can tunnel data
self.process(remote=ssl_socket)
except Exception as e:
logger.error(f'Error connecting to {self.server.remote!s}: {e!s}')
self.server.status = TUNNEL_ERROR
@ -221,17 +185,10 @@ class Handler(socketserver.BaseRequestHandler):
except Exception as e:
pass
def _run(server: ForwardServer) -> None:
logger.debug(
'Starting forwarder: %s -> %s, timeout: %d',
server.server_address,
server.remote,
server.timeout,
)
logger.debug('Starting forwarder: %s -> %s, timeout: %d', server.server_address, server.remote, server.timeout)
server.serve_forever()
logger.debug('Stoped forwarder %s -> %s', server.server_address, server.remote)
logger.debug('Stoped forwarded %s -> %s', server.server_address, server.remote)
def forward(
remote: typing.Tuple[str, int],
@ -240,7 +197,6 @@ def forward(
local_port: int = 0,
check_certificate=True,
) -> ForwardServer:
fs = ForwardServer(
remote=remote,
ticket=ticket,
@ -252,3 +208,17 @@ def forward(
threading.Thread(target=_run, args=(fs,)).start()
return fs
if __name__ == "__main__":
import sys
log = logging.getLogger()
log.setLevel(logging.DEBUG)
handler = logging.StreamHandler(sys.stdout)
handler.setLevel(logging.DEBUG)
formatter = logging.Formatter(
'%(levelname)s - %(message)s'
) # Basic log format, nice for syslog
handler.setFormatter(formatter)
log.addHandler(handler)
fs = forward(('172.27.0.1', 7777), '1'*64, local_port=49999, timeout=10, check_certificate=False)