UDS 3.4 now uses volumev3 for non legacy openstack connections (legacy maintains v2)

This commit is contained in:
Adolfo Gómez García 2021-08-11 18:59:18 +02:00
parent 1be49a6e0e
commit 68411f0726
7 changed files with 157 additions and 54 deletions

View File

@ -67,7 +67,7 @@ if __name__ == "__main__":
# Note: Signals are only checked on python code execution, so we create a timer to force call back to python # Note: Signals are only checked on python code execution, so we create a timer to force call back to python
timer = QTimer(qApp) timer = QTimer(qApp)
timer.start(1000) timer.start(1000)
timer.timeout.connect(lambda *a: None) # type: ignore # timeout can be connected to a callable timer.timeout.connect(lambda *a: None) # timeout can be connected to a callable
qApp.exec_() qApp.exec_()

View File

@ -25,7 +25,11 @@ class CheckfingerPrints(paramiko.MissingHostKeyPolicy):
if self.fingerPrints: if self.fingerPrints:
remotefingerPrints = hexlify(key.get_fingerprint()).decode().lower() remotefingerPrints = hexlify(key.get_fingerprint()).decode().lower()
if remotefingerPrints not in self.fingerPrints.split(','): if remotefingerPrints not in self.fingerPrints.split(','):
logger.error("Server {!r} has invalid fingerPrints. ({} vs {})".format(hostname, remotefingerPrints, self.fingerPrints)) logger.error(
"Server {!r} has invalid fingerPrints. ({} vs {})".format(
hostname, remotefingerPrints, self.fingerPrints
)
)
raise paramiko.SSHException( raise paramiko.SSHException(
"Server {!r} has invalid fingerPrints".format(hostname) "Server {!r} has invalid fingerPrints".format(hostname)
) )
@ -47,21 +51,39 @@ class Handler(socketserver.BaseRequestHandler):
self.thread.currentConnections += 1 self.thread.currentConnections += 1
try: try:
chan = self.ssh_transport.open_channel('direct-tcpip', chan = self.ssh_transport.open_channel(
'direct-tcpip',
(self.chain_host, self.chain_port), (self.chain_host, self.chain_port),
self.request.getpeername()) self.request.getpeername(),
)
except Exception as e: except Exception as e:
logger.exception('Incoming request to %s:%d failed: %s', self.chain_host, self.chain_port, repr(e)) logger.exception(
'Incoming request to %s:%d failed: %s',
self.chain_host,
self.chain_port,
repr(e),
)
return return
if chan is None: if chan is None:
logger.error('Incoming request to %s:%d was rejected by the SSH server.', self.chain_host, self.chain_port) logger.error(
'Incoming request to %s:%d was rejected by the SSH server.',
self.chain_host,
self.chain_port,
)
return return
logger.debug('Connected! Tunnel open %r -> %r -> %r', self.request.getpeername(), chan.getpeername(), (self.chain_host, self.chain_port)) logger.debug(
'Connected! Tunnel open %r -> %r -> %r',
self.request.getpeername(),
chan.getpeername(),
(self.chain_host, self.chain_port),
)
# self.ssh_transport.set_keepalive(10) # Keep alive every 10 seconds... # self.ssh_transport.set_keepalive(10) # Keep alive every 10 seconds...
try: try:
while self.event.is_set() is False: while self.event.is_set() is False:
r, _w, _x = select.select([self.request, chan], [], [], 1) # pylint: disable=unused-variable r, _w, _x = select.select(
[self.request, chan], [], [], 1
) # pylint: disable=unused-variable
if self.request in r: if self.request in r:
data = self.request.recv(1024) data = self.request.recv(1024)
@ -80,7 +102,10 @@ class Handler(socketserver.BaseRequestHandler):
peername = self.request.getpeername() peername = self.request.getpeername()
chan.close() chan.close()
self.request.close() self.request.close()
logger.debug('Tunnel closed from %r', peername,) logger.debug(
'Tunnel closed from %r',
peername,
)
except Exception: except Exception:
pass pass
@ -95,7 +120,18 @@ class ForwardThread(threading.Thread):
client: typing.Optional[paramiko.SSHClient] client: typing.Optional[paramiko.SSHClient]
fs: typing.Optional[ForwardServer] fs: typing.Optional[ForwardServer]
def __init__(self, server, port, username, password, localPort, redirectHost, redirectPort, waitTime, fingerPrints): def __init__(
self,
server,
port,
username,
password,
localPort,
redirectHost,
redirectPort,
waitTime,
fingerPrints,
):
threading.Thread.__init__(self) threading.Thread.__init__(self)
self.client = None self.client = None
self.fs = None self.fs = None
@ -124,7 +160,17 @@ class ForwardThread(threading.Thread):
if localPort is None: if localPort is None:
localPort = random.randrange(33000, 53000) localPort = random.randrange(33000, 53000)
ft = ForwardThread(self.server, self.port, self.username, self.password, localPort, redirectHost, redirectPort, self.waitTime, self.fingerPrints) ft = ForwardThread(
self.server,
self.port,
self.username,
self.password,
localPort,
redirectHost,
redirectPort,
self.waitTime,
self.fingerPrints,
)
ft.client = self.client ft.client = self.client
self.client.useCount += 1 # type: ignore self.client.useCount += 1 # type: ignore
ft.start() ft.start()
@ -134,7 +180,6 @@ class ForwardThread(threading.Thread):
return (ft, localPort) return (ft, localPort)
def _timerFnc(self): def _timerFnc(self):
self.timer = None self.timer = None
logger.debug('Timer fnc: %s', self.currentConnections) logger.debug('Timer fnc: %s', self.currentConnections)
@ -148,12 +193,21 @@ class ForwardThread(threading.Thread):
self.client = paramiko.SSHClient() self.client = paramiko.SSHClient()
self.client.useCount = 1 # type: ignore self.client.useCount = 1 # type: ignore
self.client.load_system_host_keys() self.client.load_system_host_keys()
self.client.set_missing_host_key_policy(CheckfingerPrints(self.fingerPrints)) self.client.set_missing_host_key_policy(
CheckfingerPrints(self.fingerPrints)
)
logger.debug('Connecting to ssh host %s:%d ...', self.server, self.port) logger.debug('Connecting to ssh host %s:%d ...', self.server, self.port)
# To disable ssh-ageng asking for passwords: allow_agent=False # To disable ssh-ageng asking for passwords: allow_agent=False
self.client.connect(self.server, self.port, username=self.username, password=self.password, timeout=5, allow_agent=False) self.client.connect(
self.server,
self.port,
username=self.username,
password=self.password,
timeout=5,
allow_agent=False,
)
except Exception: except Exception:
logger.exception('Exception connecting: ') logger.exception('Exception connecting: ')
self.status = 2 # Error self.status = 2 # Error
@ -194,7 +248,17 @@ class ForwardThread(threading.Thread):
logger.exception('Exception stopping') logger.exception('Exception stopping')
def forward(server, port, username, password, redirectHost, redirectPort, localPort=None, waitTime=10, fingerPrints=None): def forward(
server,
port,
username,
password,
redirectHost,
redirectPort,
localPort=None,
waitTime=10,
fingerPrints=None,
):
''' '''
Instantiates an ssh connection to server:port Instantiates an ssh connection to server:port
Returns the Thread created and the local redirected port as a list: (thread, port) Returns the Thread created and the local redirected port as a list: (thread, port)
@ -204,10 +268,28 @@ def forward(server, port, username, password, redirectHost, redirectPort, localP
if localPort is None: if localPort is None:
localPort = random.randrange(40000, 50000) localPort = random.randrange(40000, 50000)
logger.debug('Connecting to %s:%s using %s/%s redirecting to %s:%s, listening on 127.0.0.1:%s', logger.debug(
server, port, username, password, redirectHost, redirectPort, localPort) 'Connecting to %s:%s using %s/%s redirecting to %s:%s, listening on 127.0.0.1:%s',
server,
port,
username,
password,
redirectHost,
redirectPort,
localPort,
)
ft = ForwardThread(server, port, username, password, localPort, redirectHost, redirectPort, waitTime, fingerPrints) ft = ForwardThread(
server,
port,
username,
password,
localPort,
redirectHost,
redirectPort,
waitTime,
fingerPrints,
)
ft.start() ft.start()

View File

@ -29,8 +29,6 @@
''' '''
@author: Adolfo Gómez, dkmaster at dkmon dot com @author: Adolfo Gómez, dkmaster at dkmon dot com
''' '''
from __future__ import unicode_literals
import logging import logging
import os import os
import os.path import os.path
@ -57,7 +55,7 @@ try:
filename=logFile, filename=logFile,
filemode='a', filemode='a',
format='%(levelname)s %(asctime)s %(message)s', format='%(levelname)s %(asctime)s %(message)s',
level=LOGLEVEL level=LOGLEVEL,
) )
except Exception: except Exception:
logging.basicConfig(format='%(levelname)s %(asctime)s %(message)s', level=LOGLEVEL) logging.basicConfig(format='%(levelname)s %(asctime)s %(message)s', level=LOGLEVEL)

View File

@ -30,14 +30,13 @@
''' '''
@author: Adolfo Gómez, dkmaster at dkmon dot com @author: Adolfo Gómez, dkmaster at dkmon dot com
''' '''
from __future__ import unicode_literals
import sys import sys
LINUX = 'Linux' LINUX = 'Linux'
WINDOWS = 'Windows' WINDOWS = 'Windows'
MAC_OS_X = 'Mac os x' MAC_OS_X = 'Mac os x'
def getOs(): def getOs():
if sys.platform.startswith('linux'): if sys.platform.startswith('linux'):
return LINUX return LINUX

View File

@ -29,8 +29,6 @@
''' '''
@author: Adolfo Gómez, dkmaster at dkmon dot com @author: Adolfo Gómez, dkmaster at dkmon dot com
''' '''
# pylint: disable=c-extension-no-member,no-name-in-module
import json import json
import bz2 import bz2
import base64 import base64
@ -63,9 +61,11 @@ CertCallbackType = typing.Callable[[str, str], bool]
class UDSException(Exception): class UDSException(Exception):
pass pass
class RetryException(UDSException): class RetryException(UDSException):
pass pass
class InvalidVersion(UDSException): class InvalidVersion(UDSException):
downloadUrl: str downloadUrl: str
@ -73,6 +73,7 @@ class InvalidVersion(UDSException):
super().__init__(downloadUrl) super().__init__(downloadUrl)
self.downloadUrl = downloadUrl self.downloadUrl = downloadUrl
class RestApi: class RestApi:
_restApiUrl: str # base Rest API URL _restApiUrl: str # base Rest API URL
@ -90,14 +91,18 @@ class RestApi:
self._callbackInvalidCert = callbackInvalidCert self._callbackInvalidCert = callbackInvalidCert
self._serverVersion = '' self._serverVersion = ''
def get(self, url: str, params: typing.Optional[typing.Mapping[str, str]] = None) -> typing.Any: def get(
self, url: str, params: typing.Optional[typing.Mapping[str, str]] = None
) -> typing.Any:
if params: if params:
url += '?' + '&'.join( url += '?' + '&'.join(
'{}={}'.format(k, urllib.parse.quote(str(v).encode('utf8'))) '{}={}'.format(k, urllib.parse.quote(str(v).encode('utf8')))
for k, v in params.items() for k, v in params.items()
) )
return json.loads(RestApi.getUrl(self._restApiUrl + url, self._callbackInvalidCert)) return json.loads(
RestApi.getUrl(self._restApiUrl + url, self._callbackInvalidCert)
)
def processError(self, data: typing.Any) -> None: def processError(self, data: typing.Any) -> None:
if 'error' in data: if 'error' in data:
@ -106,7 +111,6 @@ class RestApi:
raise UDSException(data['error']) raise UDSException(data['error'])
def getVersion(self) -> str: def getVersion(self) -> str:
'''Gets and stores the serverVersion. '''Gets and stores the serverVersion.
Also checks that the version is valid for us. If not, Also checks that the version is valid for us. If not,
@ -127,7 +131,9 @@ class RestApi:
except Exception as e: except Exception as e:
raise UDSException(e) raise UDSException(e)
def getScriptAndParams(self, ticket: str, scrambler: str) -> typing.Tuple[str, typing.Any]: def getScriptAndParams(
self, ticket: str, scrambler: str
) -> typing.Tuple[str, typing.Any]:
'''Gets the transport script, validates it if necesary '''Gets the transport script, validates it if necesary
and returns it''' and returns it'''
try: try:
@ -173,7 +179,6 @@ class RestApi:
# exec(script.decode("utf-8"), globals(), {'parent': self, 'sp': params}) # exec(script.decode("utf-8"), globals(), {'parent': self, 'sp': params})
@staticmethod @staticmethod
def _open( def _open(
url: str, certErrorCallback: typing.Optional[CertCallbackType] = None url: str, certErrorCallback: typing.Optional[CertCallbackType] = None
@ -193,7 +198,8 @@ class RestApi:
if url.startswith('https'): if url.startswith('https'):
port = port or '443' port = port or '443'
with ctx.wrap_socket( with ctx.wrap_socket(
socket.socket(socket.AF_INET, socket.SOCK_STREAM), server_hostname=hostname socket.socket(socket.AF_INET, socket.SOCK_STREAM),
server_hostname=hostname,
) as s: ) as s:
s.connect((hostname, int(port))) s.connect((hostname, int(port)))
# Get binary certificate # Get binary certificate
@ -211,9 +217,12 @@ class RestApi:
def urlopen(url: str): def urlopen(url: str):
# Generate the request with the headers # Generate the request with the headers
req = urllib.request.Request(url, headers={ req = urllib.request.Request(
url,
headers={
'User-Agent': os_detector.getOs() + " - UDS Connector " + VERSION 'User-Agent': os_detector.getOs() + " - UDS Connector " + VERSION
}) },
)
return urllib.request.urlopen(req, context=ctx) return urllib.request.urlopen(req, context=ctx)
try: try:

View File

@ -163,7 +163,9 @@ def unlinkFiles(early: bool = False) -> None:
def addTaskToWait(task: typing.Any, includeSubprocess: bool = False) -> None: def addTaskToWait(task: typing.Any, includeSubprocess: bool = False) -> None:
logger.debug( logger.debug(
'Added task %s to wait %s', task, 'with subprocesses' if includeSubprocess else '' 'Added task %s to wait %s',
task,
'with subprocesses' if includeSubprocess else '',
) )
_tasksToWait.append((task, includeSubprocess)) _tasksToWait.append((task, includeSubprocess))
@ -178,12 +180,22 @@ def waitForTasks() -> None:
elif hasattr(task, 'wait'): elif hasattr(task, 'wait'):
task.wait() task.wait()
# If wait for spanwed process (look for process with task pid) and we can look for them... # If wait for spanwed process (look for process with task pid) and we can look for them...
logger.debug('Psutil: %s, waitForSubp: %s, hasattr: %s', psutil, waitForSubp, hasattr(task, 'pid')) logger.debug(
'Psutil: %s, waitForSubp: %s, hasattr: %s',
psutil,
waitForSubp,
hasattr(task, 'pid'),
)
if psutil and waitForSubp and hasattr(task, 'pid'): if psutil and waitForSubp and hasattr(task, 'pid'):
subProcesses = list(filter( subProcesses = list(
lambda x: x.ppid() == task.pid, psutil.process_iter(attrs=('ppid',)) filter(
)) lambda x: x.ppid() == task.pid, # type: ignore
logger.debug('Waiting for subprocesses... %s, %s', task.pid, subProcesses) psutil.process_iter(attrs=('ppid',)),
)
)
logger.debug(
'Waiting for subprocesses... %s, %s', task.pid, subProcesses
)
for i in subProcesses: for i in subProcesses:
logger.debug('Found %s', i) logger.debug('Found %s', i)
i.wait() i.wait()
@ -229,6 +241,7 @@ def verifySignature(script: bytes, signature: bytes) -> bool:
# If no exception, the script was fine... # If no exception, the script was fine...
return True return True
def getCaCertsFile() -> str: def getCaCertsFile() -> str:
try: try:
if os.path.exists(certifi.where()): if os.path.exists(certifi.where()):

View File

@ -145,6 +145,7 @@ class Client: # pylint: disable=too-many-public-methods
_tokenId: typing.Optional[str] _tokenId: typing.Optional[str]
_catalog: typing.Optional[typing.List[typing.Dict[str, typing.Any]]] _catalog: typing.Optional[typing.List[typing.Dict[str, typing.Any]]]
_isLegacy: bool _isLegacy: bool
_volume: str
_access: typing.Optional[str] _access: typing.Optional[str]
_domain: str _domain: str
_username: str _username: str
@ -188,6 +189,7 @@ class Client: # pylint: disable=too-many-public-methods
self._project = None self._project = None
self._region = region self._region = region
self._timeout = 10 self._timeout = 10
self._volume = 'volumev2' if self._isLegacy else 'volumev3'
if legacyVersion: if legacyVersion:
self._authUrl = 'http{}://{}:{}/'.format('s' if useSSL else '', host, port) self._authUrl = 'http{}://{}:{}/'.format('s' if useSSL else '', host, port)
@ -271,6 +273,7 @@ class Client: # pylint: disable=too-many-public-methods
if self._projectId is not None: if self._projectId is not None:
self._catalog = token['catalog'] self._catalog = token['catalog']
def ensureAuthenticated(self) -> None: def ensureAuthenticated(self) -> None:
if ( if (
self._authenticated is False self._authenticated is False
@ -331,7 +334,7 @@ class Client: # pylint: disable=too-many-public-methods
@authProjectRequired @authProjectRequired
def listVolumeTypes(self) -> typing.Iterable[typing.Any]: def listVolumeTypes(self) -> typing.Iterable[typing.Any]:
return getRecurringUrlJson( return getRecurringUrlJson(
self._getEndpointFor('volumev2') + '/types', self._getEndpointFor(self._volume) + '/types',
self._session, self._session,
headers=self._requestHeaders(), headers=self._requestHeaders(),
key='volume_types', key='volume_types',
@ -341,9 +344,8 @@ class Client: # pylint: disable=too-many-public-methods
@authProjectRequired @authProjectRequired
def listVolumes(self) -> typing.Iterable[typing.Any]: def listVolumes(self) -> typing.Iterable[typing.Any]:
# self._getEndpointFor('volumev2') + '/volumes'
return getRecurringUrlJson( return getRecurringUrlJson(
self._getEndpointFor('volumev2') + '/volumes/detail', self._getEndpointFor(self._volume) + '/volumes/detail',
self._session, self._session,
headers=self._requestHeaders(), headers=self._requestHeaders(),
key='volumes', key='volumes',
@ -356,7 +358,7 @@ class Client: # pylint: disable=too-many-public-methods
self, volumeId: typing.Optional[typing.Dict[str, typing.Any]] = None self, volumeId: typing.Optional[typing.Dict[str, typing.Any]] = None
) -> typing.Iterable[typing.Any]: ) -> typing.Iterable[typing.Any]:
for s in getRecurringUrlJson( for s in getRecurringUrlJson(
self._getEndpointFor('volumev2') + '/snapshots', self._getEndpointFor(self._volume) + '/snapshots',
self._session, self._session,
headers=self._requestHeaders(), headers=self._requestHeaders(),
key='snapshots', key='snapshots',
@ -474,7 +476,7 @@ class Client: # pylint: disable=too-many-public-methods
@authProjectRequired @authProjectRequired
def getVolume(self, volumeId: str) -> typing.Dict[str, typing.Any]: def getVolume(self, volumeId: str) -> typing.Dict[str, typing.Any]:
r = self._session.get( r = self._session.get(
self._getEndpointFor('volumev2') self._getEndpointFor(self._volume)
+ '/volumes/{volume_id}'.format(volume_id=volumeId), + '/volumes/{volume_id}'.format(volume_id=volumeId),
headers=self._requestHeaders(), headers=self._requestHeaders(),
verify=VERIFY_SSL, verify=VERIFY_SSL,
@ -492,7 +494,7 @@ class Client: # pylint: disable=too-many-public-methods
creating, available, deleting, error, error_deleting creating, available, deleting, error, error_deleting
""" """
r = self._session.get( r = self._session.get(
self._getEndpointFor('volumev2') self._getEndpointFor(self._volume)
+ '/snapshots/{snapshot_id}'.format(snapshot_id=snapshotId), + '/snapshots/{snapshot_id}'.format(snapshot_id=snapshotId),
headers=self._requestHeaders(), headers=self._requestHeaders(),
verify=VERIFY_SSL, verify=VERIFY_SSL,
@ -518,7 +520,7 @@ class Client: # pylint: disable=too-many-public-methods
data['snapshot']['description'] = description data['snapshot']['description'] = description
r = self._session.put( r = self._session.put(
self._getEndpointFor('volumev2') self._getEndpointFor(self._volume)
+ '/snapshots/{snapshot_id}'.format(snapshot_id=snapshotId), + '/snapshots/{snapshot_id}'.format(snapshot_id=snapshotId),
data=json.dumps(data), data=json.dumps(data),
headers=self._requestHeaders(), headers=self._requestHeaders(),
@ -547,7 +549,7 @@ class Client: # pylint: disable=too-many-public-methods
# First, ensure volume is in state "available" # First, ensure volume is in state "available"
r = self._session.post( r = self._session.post(
self._getEndpointFor('volumev2') + '/snapshots', self._getEndpointFor(self._volume) + '/snapshots',
data=json.dumps(data), data=json.dumps(data),
headers=self._requestHeaders(), headers=self._requestHeaders(),
verify=VERIFY_SSL, verify=VERIFY_SSL,
@ -575,7 +577,7 @@ class Client: # pylint: disable=too-many-public-methods
} }
r = self._session.post( r = self._session.post(
self._getEndpointFor('volumev2') + '/volumes', self._getEndpointFor(self._volume) + '/volumes',
data=json.dumps(data), data=json.dumps(data),
headers=self._requestHeaders(), headers=self._requestHeaders(),
verify=VERIFY_SSL, verify=VERIFY_SSL,
@ -662,7 +664,7 @@ class Client: # pylint: disable=too-many-public-methods
@authProjectRequired @authProjectRequired
def deleteSnapshot(self, snapshotId: str) -> None: def deleteSnapshot(self, snapshotId: str) -> None:
r = self._session.delete( r = self._session.delete(
self._getEndpointFor('volumev2') self._getEndpointFor(self._volume)
+ '/snapshots/{snapshot_id}'.format(snapshot_id=snapshotId), + '/snapshots/{snapshot_id}'.format(snapshot_id=snapshotId),
headers=self._requestHeaders(), headers=self._requestHeaders(),
verify=VERIFY_SSL, verify=VERIFY_SSL,