mirror of
https://github.com/ansible/awx.git
synced 2024-11-01 08:21:15 +03:00
Merge pull request #494 from chrismeyersfsu/feature-socket_tests
adds socket tests
This commit is contained in:
commit
c52c7c2f9f
@ -46,43 +46,52 @@ class SocketSession(object):
|
||||
return bool(not auth_token.is_expired())
|
||||
|
||||
class SocketSessionManager(object):
|
||||
socket_sessions = []
|
||||
socket_session_token_key_map = {}
|
||||
|
||||
@classmethod
|
||||
def _prune(cls):
|
||||
if len(cls.socket_sessions) > 1000:
|
||||
session = cls.socket_session[0]
|
||||
del cls.socket_session_token_key_map[session.token_key]
|
||||
cls.sessions = cls.socket_sessions[1:]
|
||||
def __init__(self):
|
||||
self.SESSIONS_MAX = 1000
|
||||
self.socket_sessions = []
|
||||
self.socket_session_token_key_map = {}
|
||||
|
||||
def _prune(self):
|
||||
if len(self.socket_sessions) > self.SESSIONS_MAX:
|
||||
session = self.socket_sessions[0]
|
||||
entries = self.socket_session_token_key_map[session.token_key]
|
||||
del entries[session.session_id]
|
||||
if len(entries) == 0:
|
||||
del self.socket_session_token_key_map[session.token_key]
|
||||
self.socket_sessions.pop(0)
|
||||
|
||||
'''
|
||||
Returns an dict of sessions <session_id, session>
|
||||
'''
|
||||
@classmethod
|
||||
def lookup(cls, token_key=None):
|
||||
def lookup(self, token_key=None):
|
||||
if not token_key:
|
||||
raise ValueError("token_key required")
|
||||
return cls.socket_session_token_key_map.get(token_key, None)
|
||||
return self.socket_session_token_key_map.get(token_key, None)
|
||||
|
||||
@classmethod
|
||||
def add_session(cls, session):
|
||||
cls.socket_sessions.append(session)
|
||||
entries = cls.socket_session_token_key_map.get(session.token_key, None)
|
||||
def add_session(self, session):
|
||||
self.socket_sessions.append(session)
|
||||
entries = self.socket_session_token_key_map.get(session.token_key, None)
|
||||
if not entries:
|
||||
entries = {}
|
||||
cls.socket_session_token_key_map[session.token_key] = entries
|
||||
self.socket_session_token_key_map[session.token_key] = entries
|
||||
entries[session.session_id] = session
|
||||
cls._prune()
|
||||
self._prune()
|
||||
return session
|
||||
|
||||
class SocketController(object):
|
||||
server = None
|
||||
|
||||
@classmethod
|
||||
def broadcast_packet(cls, packet):
|
||||
def __init__(self, SocketSessionManager):
|
||||
self.server = None
|
||||
self.SocketSessionManager = SocketSessionManager
|
||||
|
||||
def add_session(self, session):
|
||||
return self.SocketSessionManager.add_session(session)
|
||||
|
||||
def broadcast_packet(self, packet):
|
||||
# Broadcast message to everyone at endpoint
|
||||
# Loop over the 'raw' list of sockets (don't trust our list)
|
||||
for session_id, socket in list(cls.server.sockets.iteritems()):
|
||||
for session_id, socket in list(self.server.sockets.iteritems()):
|
||||
socket_session = socket.session.get('socket_session', None)
|
||||
if socket_session and socket_session.is_valid():
|
||||
try:
|
||||
@ -91,11 +100,10 @@ class SocketController(object):
|
||||
logger.error("Error sending client packet to %s: %s" % (str(session_id), str(packet)))
|
||||
logger.error("Error was: " + str(e))
|
||||
|
||||
@classmethod
|
||||
def send_packet(cls, packet, token_key):
|
||||
def send_packet(self, packet, token_key):
|
||||
if not token_key:
|
||||
raise ValueError("token_key is required")
|
||||
socket_sessions = SocketSessionManager.lookup(token_key=token_key)
|
||||
socket_sessions = self.SocketSessionManager.lookup(token_key=token_key)
|
||||
# We may not find the socket_session if the user disconnected
|
||||
# (it's actually more compliciated than that because of our prune logic)
|
||||
if not socket_sessions:
|
||||
@ -112,11 +120,12 @@ class SocketController(object):
|
||||
logger.error("Error sending client packet to %s: %s" % (str(socket_session.session_id), str(packet)))
|
||||
logger.error("Error was: " + str(e))
|
||||
|
||||
@classmethod
|
||||
def set_server(cls, server):
|
||||
cls.server = server
|
||||
def set_server(self, server):
|
||||
self.server = server
|
||||
return server
|
||||
|
||||
socketController = SocketController(SocketSessionManager())
|
||||
|
||||
#
|
||||
# Socket session is attached to self.session['socket_session']
|
||||
# self.session and self.socket.session point to the same dict
|
||||
@ -140,7 +149,7 @@ class TowerBaseNamespace(BaseNamespace):
|
||||
socket_session = SocketSession(self.socket.sessid, request_token, self.socket)
|
||||
if socket_session.is_db_token_valid():
|
||||
self.session['socket_session'] = socket_session
|
||||
SocketSessionManager.add_session(socket_session)
|
||||
socketController.add_session(socket_session)
|
||||
else:
|
||||
socket_session.invalidate()
|
||||
|
||||
@ -240,9 +249,9 @@ def notification_handler(server):
|
||||
|
||||
if 'token_key' in message:
|
||||
# Best practice not to send the token over the socket
|
||||
SocketController.send_packet(packet, message.pop('token_key'))
|
||||
socketController.send_packet(packet, message.pop('token_key'))
|
||||
else:
|
||||
SocketController.broadcast_packet(packet)
|
||||
socketController.broadcast_packet(packet)
|
||||
|
||||
class Command(NoArgsCommand):
|
||||
'''
|
||||
@ -270,7 +279,7 @@ class Command(NoArgsCommand):
|
||||
logger.info('Listening on port http://0.0.0.0:' + str(socketio_listen_port))
|
||||
server = SocketIOServer(('0.0.0.0', socketio_listen_port), TowerSocket(), resource='socket.io')
|
||||
|
||||
SocketController.set_server(server)
|
||||
socketController.set_server(server)
|
||||
handler_thread = Thread(target=notification_handler, args=(server,))
|
||||
handler_thread.daemon = True
|
||||
handler_thread.start()
|
||||
|
@ -8,4 +8,5 @@ from .commands_monolithic import * # noqa
|
||||
from .cleanup_facts import * # noqa
|
||||
from .age_deleted import * # noqa
|
||||
from .remove_instance import * # noqa
|
||||
from .run_socketio_service import * # noqa
|
||||
|
||||
|
116
awx/main/tests/commands/run_socketio_service.py
Normal file
116
awx/main/tests/commands/run_socketio_service.py
Normal file
@ -0,0 +1,116 @@
|
||||
# Copyright (c) 2015 Ansible, Inc.
|
||||
# All Rights Reserved
|
||||
|
||||
# Python
|
||||
from mock import MagicMock, Mock
|
||||
|
||||
# Django
|
||||
from django.test import SimpleTestCase
|
||||
|
||||
# AWX
|
||||
from awx.fact.models.fact import * # noqa
|
||||
from awx.main.management.commands.run_socketio_service import SocketSessionManager, SocketSession, SocketController
|
||||
|
||||
__all__ = ['SocketSessionManagerUnitTest', 'SocketControllerUnitTest',]
|
||||
|
||||
class WeakRefable():
|
||||
pass
|
||||
|
||||
class SocketSessionManagerUnitTest(SimpleTestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.session_manager = SocketSessionManager()
|
||||
super(SocketSessionManagerUnitTest, self).setUp()
|
||||
|
||||
def create_sessions(self, count, token_key=None):
|
||||
self.sessions = []
|
||||
self.count = count
|
||||
for i in range(0, count):
|
||||
self.sessions.append(SocketSession(i, token_key or i, WeakRefable()))
|
||||
self.session_manager.add_session(self.sessions[i])
|
||||
|
||||
def test_multiple_session_diff_token(self):
|
||||
self.create_sessions(10)
|
||||
|
||||
for s in self.sessions:
|
||||
self.assertIn(s.token_key, self.session_manager.socket_session_token_key_map)
|
||||
self.assertEqual(s, self.session_manager.socket_session_token_key_map[s.token_key][s.session_id])
|
||||
|
||||
|
||||
def test_multiple_session_same_token(self):
|
||||
self.create_sessions(10, token_key='foo')
|
||||
|
||||
sessions_dict = self.session_manager.lookup("foo")
|
||||
self.assertEqual(len(sessions_dict), 10)
|
||||
for s in self.sessions:
|
||||
self.assertIn(s.session_id, sessions_dict)
|
||||
self.assertEqual(s, sessions_dict[s.session_id])
|
||||
|
||||
def test_prune_sessions_max(self):
|
||||
self.create_sessions(self.session_manager.SESSIONS_MAX + 10)
|
||||
|
||||
self.assertEqual(len(self.session_manager.socket_sessions), self.session_manager.SESSIONS_MAX)
|
||||
|
||||
|
||||
class SocketControllerUnitTest(SimpleTestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.socket_controller = SocketController(SocketSessionManager())
|
||||
server = Mock()
|
||||
self.socket_controller.set_server(server)
|
||||
super(SocketControllerUnitTest, self).setUp()
|
||||
|
||||
def create_clients(self, count, token_key=None):
|
||||
self.sessions = []
|
||||
self.sockets =[]
|
||||
self.count = count
|
||||
self.sockets_dict = {}
|
||||
for i in range(0, count):
|
||||
if isinstance(token_key, list):
|
||||
token_key_actual = token_key[i]
|
||||
else:
|
||||
token_key_actual = token_key or i
|
||||
socket = MagicMock(session=dict())
|
||||
socket_session = SocketSession(i, token_key_actual, socket)
|
||||
self.sockets.append(socket)
|
||||
self.sessions.append(socket_session)
|
||||
self.sockets_dict[i] = socket
|
||||
self.socket_controller.add_session(socket_session)
|
||||
|
||||
socket.session['socket_session'] = socket_session
|
||||
socket.send_packet = Mock()
|
||||
self.socket_controller.server.sockets = self.sockets_dict
|
||||
|
||||
def test_broadcast_packet(self):
|
||||
self.create_clients(10)
|
||||
packet = {
|
||||
"hello": "world"
|
||||
}
|
||||
self.socket_controller.broadcast_packet(packet)
|
||||
for s in self.sockets:
|
||||
s.send_packet.assert_called_with(packet)
|
||||
|
||||
def test_send_packet(self):
|
||||
self.create_clients(5, token_key=[0, 1, 2, 3, 4])
|
||||
packet = {
|
||||
"hello": "world"
|
||||
}
|
||||
self.socket_controller.send_packet(packet, 2)
|
||||
self.assertEqual(0, len(self.sockets[0].send_packet.mock_calls))
|
||||
self.assertEqual(0, len(self.sockets[1].send_packet.mock_calls))
|
||||
self.sockets[2].send_packet.assert_called_once_with(packet)
|
||||
self.assertEqual(0, len(self.sockets[3].send_packet.mock_calls))
|
||||
self.assertEqual(0, len(self.sockets[4].send_packet.mock_calls))
|
||||
|
||||
def test_send_packet_multiple_sessions_one_token(self):
|
||||
self.create_clients(5, token_key=[0, 1, 1, 1, 2])
|
||||
packet = {
|
||||
"hello": "world"
|
||||
}
|
||||
self.socket_controller.send_packet(packet, 1)
|
||||
self.assertEqual(0, len(self.sockets[0].send_packet.mock_calls))
|
||||
self.sockets[1].send_packet.assert_called_once_with(packet)
|
||||
self.sockets[2].send_packet.assert_called_once_with(packet)
|
||||
self.sockets[3].send_packet.assert_called_once_with(packet)
|
||||
self.assertEqual(0, len(self.sockets[4].send_packet.mock_calls))
|
||||
|
Loading…
Reference in New Issue
Block a user