1
0
mirror of https://github.com/dkmstr/openuds.git synced 2024-12-22 13:34:04 +03:00

A better implementation for shared connections

This commit is contained in:
Adolfo Gómez García 2016-04-28 09:34:03 +02:00
parent 0fb7d5ed1b
commit 94842ce0ef

View File

@ -73,7 +73,6 @@ class Handler(SocketServer.BaseRequestHandler):
class ForwardThread(threading.Thread): class ForwardThread(threading.Thread):
status = 0 # Connecting status = 0 # Connecting
clientUseCounter = 0
def __init__(self, server, port, username, password, localPort, redirectHost, redirectPort, waitTime): def __init__(self, server, port, username, password, localPort, redirectHost, redirectPort, waitTime):
threading.Thread.__init__(self) threading.Thread.__init__(self)
@ -104,6 +103,7 @@ class ForwardThread(threading.Thread):
ft = ForwardThread(self.server, self.port, self.username, self.password, localPort, redirectHost, redirectPort. self.waitTime) ft = ForwardThread(self.server, self.port, self.username, self.password, localPort, redirectHost, redirectPort. self.waitTime)
ft.client = self.client ft.client = self.client
self.client.useCount += 1 # One more using this client
ft.start() ft.start()
while ft.status == 0: while ft.status == 0:
@ -122,6 +122,7 @@ class ForwardThread(threading.Thread):
def run(self): def run(self):
if self.client is None: if self.client is None:
self.client = paramiko.SSHClient() self.client = paramiko.SSHClient()
self.client.useCount = 1 # Custom added variable, to keep track on when to close tunnel
self.client.load_system_host_keys() self.client.load_system_host_keys()
self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
@ -134,6 +135,8 @@ class ForwardThread(threading.Thread):
self.status = 2 # Error self.status = 2 # Error
return return
self.clientUseCount += 1
class SubHandler(Handler): class SubHandler(Handler):
chain_host = self.redirectHost chain_host = self.redirectHost
chain_port = self.redirectPort chain_port = self.redirectPort
@ -147,8 +150,6 @@ class ForwardThread(threading.Thread):
self.status = 1 # Ok, listening self.status = 1 # Ok, listening
ForwardThread.clientUseCounter += 1
self.fs = ForwardServer(('', self.localPort), SubHandler) self.fs = ForwardServer(('', self.localPort), SubHandler)
self.fs.serve_forever() self.fs.serve_forever()
@ -161,9 +162,10 @@ class ForwardThread(threading.Thread):
self.fs.shutdown() self.fs.shutdown()
if self.client is not None: if self.client is not None:
ForwardThread.clientUseCounter -= 1 self.client.useCount -= 1
if ForwardThread.clientUseCounter == 0: if self.client.useCount == 0:
self.client.close() self.client.close()
self.client = None # Clean up
except Exception: except Exception:
logger.exception('Exception stopping') logger.exception('Exception stopping')
pass pass