mirror of
https://github.com/dkmstr/openuds.git
synced 2025-01-08 21:18:00 +03:00
Improved connection tunnel timeout
This commit is contained in:
parent
8feef1d3f9
commit
d72723d6f2
@ -62,7 +62,6 @@ class ForwardServer(socketserver.ThreadingTCPServer):
|
||||
ticket: str
|
||||
stop_flag: threading.Event
|
||||
can_stop: bool
|
||||
timeout: int
|
||||
timer: typing.Optional[threading.Timer]
|
||||
check_certificate: bool
|
||||
keep_listening: bool
|
||||
@ -78,15 +77,17 @@ class ForwardServer(socketserver.ThreadingTCPServer):
|
||||
check_certificate: bool = True,
|
||||
keep_listening: bool = False,
|
||||
) -> None:
|
||||
local_port = local_port or random.randrange(33000, 53000)
|
||||
# Negative values for timeout, means "accept always connections"
|
||||
# "but if no connection is stablished on timeout (positive)"
|
||||
# "stop the listener"
|
||||
# Note that this is for backwards compatibility, better use "keep_listening"
|
||||
if timeout < 0:
|
||||
keep_listening = True
|
||||
timeout = abs(timeout)
|
||||
|
||||
super().__init__(server_address=(LISTEN_ADDRESS, local_port), RequestHandlerClass=Handler)
|
||||
self.remote = remote
|
||||
self.ticket = ticket
|
||||
# Negative values for timeout, means "accept always connections"
|
||||
# "but if no connection is stablished on timeout (positive)"
|
||||
# "stop the listener"
|
||||
self.timeout = int(time.time()) + timeout if timeout > 0 else 0
|
||||
self.check_certificate = check_certificate
|
||||
self.keep_listening = keep_listening
|
||||
self.stop_flag = threading.Event() # False initial
|
||||
@ -95,7 +96,7 @@ class ForwardServer(socketserver.ThreadingTCPServer):
|
||||
self.status = TUNNEL_LISTENING
|
||||
self.can_stop = False
|
||||
|
||||
timeout = abs(timeout) or 60
|
||||
timeout = timeout or 60
|
||||
self.timer = threading.Timer(abs(timeout), ForwardServer.__checkStarted, args=(self,))
|
||||
self.timer.start()
|
||||
|
||||
@ -154,10 +155,14 @@ class ForwardServer(socketserver.ThreadingTCPServer):
|
||||
@property
|
||||
def stoppable(self) -> bool:
|
||||
logger.debug('Is stoppable: %s', self.can_stop)
|
||||
return self.can_stop or (self.timeout != 0 and int(time.time()) > self.timeout)
|
||||
return self.can_stop
|
||||
|
||||
@staticmethod
|
||||
def __checkStarted(fs: 'ForwardServer') -> None:
|
||||
# As soon as the timer is fired, the server can be stopped
|
||||
# This means that:
|
||||
# * If not connections are stablished, the server will be stopped
|
||||
# * If no "keep_listening" is set, the server will not allow any new connections
|
||||
logger.debug('New connection limit reached')
|
||||
fs.timer = None
|
||||
fs.can_stop = True
|
||||
@ -232,10 +237,9 @@ class Handler(socketserver.BaseRequestHandler):
|
||||
|
||||
def _run(server: ForwardServer) -> None:
|
||||
logger.debug(
|
||||
'Starting forwarder: %s -> %s, timeout: %d',
|
||||
'Starting forwarder: %s -> %s',
|
||||
server.server_address,
|
||||
server.remote,
|
||||
server.timeout,
|
||||
)
|
||||
server.serve_forever()
|
||||
logger.debug('Stoped forwarder %s -> %s', server.server_address, server.remote)
|
||||
@ -248,7 +252,6 @@ def forward(
|
||||
local_port: int = 0,
|
||||
check_certificate=True,
|
||||
keep_listening=True,
|
||||
initial_payload: PayLoadType = None,
|
||||
) -> ForwardServer:
|
||||
fs = ForwardServer(
|
||||
remote=remote,
|
||||
@ -280,7 +283,8 @@ if __name__ == "__main__":
|
||||
fs = forward(
|
||||
('172.27.0.1', 7777),
|
||||
ticket,
|
||||
local_port=49999,
|
||||
local_port=0,
|
||||
timeout=-20,
|
||||
check_certificate=False,
|
||||
)
|
||||
print('Listening on port', fs.server_address)
|
||||
|
Loading…
Reference in New Issue
Block a user