1
0
mirror of https://github.com/ansible/awx.git synced 2024-11-01 08:21:15 +03:00

Merge pull request #494 from chrismeyersfsu/feature-socket_tests

adds socket tests
This commit is contained in:
Chris Meyers 2015-11-06 14:39:03 -05:00
commit c52c7c2f9f
3 changed files with 157 additions and 31 deletions

View File

@ -46,43 +46,52 @@ class SocketSession(object):
return bool(not auth_token.is_expired())
class SocketSessionManager(object):
socket_sessions = []
socket_session_token_key_map = {}
@classmethod
def _prune(cls):
if len(cls.socket_sessions) > 1000:
session = cls.socket_session[0]
del cls.socket_session_token_key_map[session.token_key]
cls.sessions = cls.socket_sessions[1:]
def __init__(self):
self.SESSIONS_MAX = 1000
self.socket_sessions = []
self.socket_session_token_key_map = {}
def _prune(self):
if len(self.socket_sessions) > self.SESSIONS_MAX:
session = self.socket_sessions[0]
entries = self.socket_session_token_key_map[session.token_key]
del entries[session.session_id]
if len(entries) == 0:
del self.socket_session_token_key_map[session.token_key]
self.socket_sessions.pop(0)
'''
Returns an dict of sessions <session_id, session>
'''
@classmethod
def lookup(cls, token_key=None):
def lookup(self, token_key=None):
if not token_key:
raise ValueError("token_key required")
return cls.socket_session_token_key_map.get(token_key, None)
return self.socket_session_token_key_map.get(token_key, None)
@classmethod
def add_session(cls, session):
cls.socket_sessions.append(session)
entries = cls.socket_session_token_key_map.get(session.token_key, None)
def add_session(self, session):
self.socket_sessions.append(session)
entries = self.socket_session_token_key_map.get(session.token_key, None)
if not entries:
entries = {}
cls.socket_session_token_key_map[session.token_key] = entries
self.socket_session_token_key_map[session.token_key] = entries
entries[session.session_id] = session
cls._prune()
self._prune()
return session
class SocketController(object):
server = None
@classmethod
def broadcast_packet(cls, packet):
def __init__(self, SocketSessionManager):
self.server = None
self.SocketSessionManager = SocketSessionManager
def add_session(self, session):
return self.SocketSessionManager.add_session(session)
def broadcast_packet(self, packet):
# Broadcast message to everyone at endpoint
# Loop over the 'raw' list of sockets (don't trust our list)
for session_id, socket in list(cls.server.sockets.iteritems()):
for session_id, socket in list(self.server.sockets.iteritems()):
socket_session = socket.session.get('socket_session', None)
if socket_session and socket_session.is_valid():
try:
@ -91,11 +100,10 @@ class SocketController(object):
logger.error("Error sending client packet to %s: %s" % (str(session_id), str(packet)))
logger.error("Error was: " + str(e))
@classmethod
def send_packet(cls, packet, token_key):
def send_packet(self, packet, token_key):
if not token_key:
raise ValueError("token_key is required")
socket_sessions = SocketSessionManager.lookup(token_key=token_key)
socket_sessions = self.SocketSessionManager.lookup(token_key=token_key)
# We may not find the socket_session if the user disconnected
# (it's actually more compliciated than that because of our prune logic)
if not socket_sessions:
@ -112,11 +120,12 @@ class SocketController(object):
logger.error("Error sending client packet to %s: %s" % (str(socket_session.session_id), str(packet)))
logger.error("Error was: " + str(e))
@classmethod
def set_server(cls, server):
cls.server = server
def set_server(self, server):
self.server = server
return server
socketController = SocketController(SocketSessionManager())
#
# Socket session is attached to self.session['socket_session']
# self.session and self.socket.session point to the same dict
@ -140,7 +149,7 @@ class TowerBaseNamespace(BaseNamespace):
socket_session = SocketSession(self.socket.sessid, request_token, self.socket)
if socket_session.is_db_token_valid():
self.session['socket_session'] = socket_session
SocketSessionManager.add_session(socket_session)
socketController.add_session(socket_session)
else:
socket_session.invalidate()
@ -240,9 +249,9 @@ def notification_handler(server):
if 'token_key' in message:
# Best practice not to send the token over the socket
SocketController.send_packet(packet, message.pop('token_key'))
socketController.send_packet(packet, message.pop('token_key'))
else:
SocketController.broadcast_packet(packet)
socketController.broadcast_packet(packet)
class Command(NoArgsCommand):
'''
@ -270,7 +279,7 @@ class Command(NoArgsCommand):
logger.info('Listening on port http://0.0.0.0:' + str(socketio_listen_port))
server = SocketIOServer(('0.0.0.0', socketio_listen_port), TowerSocket(), resource='socket.io')
SocketController.set_server(server)
socketController.set_server(server)
handler_thread = Thread(target=notification_handler, args=(server,))
handler_thread.daemon = True
handler_thread.start()

View File

@ -8,4 +8,5 @@ from .commands_monolithic import * # noqa
from .cleanup_facts import * # noqa
from .age_deleted import * # noqa
from .remove_instance import * # noqa
from .run_socketio_service import * # noqa

View File

@ -0,0 +1,116 @@
# Copyright (c) 2015 Ansible, Inc.
# All Rights Reserved
# Python
from mock import MagicMock, Mock
# Django
from django.test import SimpleTestCase
# AWX
from awx.fact.models.fact import * # noqa
from awx.main.management.commands.run_socketio_service import SocketSessionManager, SocketSession, SocketController
__all__ = ['SocketSessionManagerUnitTest', 'SocketControllerUnitTest',]
class WeakRefable():
pass
class SocketSessionManagerUnitTest(SimpleTestCase):
def setUp(self):
self.session_manager = SocketSessionManager()
super(SocketSessionManagerUnitTest, self).setUp()
def create_sessions(self, count, token_key=None):
self.sessions = []
self.count = count
for i in range(0, count):
self.sessions.append(SocketSession(i, token_key or i, WeakRefable()))
self.session_manager.add_session(self.sessions[i])
def test_multiple_session_diff_token(self):
self.create_sessions(10)
for s in self.sessions:
self.assertIn(s.token_key, self.session_manager.socket_session_token_key_map)
self.assertEqual(s, self.session_manager.socket_session_token_key_map[s.token_key][s.session_id])
def test_multiple_session_same_token(self):
self.create_sessions(10, token_key='foo')
sessions_dict = self.session_manager.lookup("foo")
self.assertEqual(len(sessions_dict), 10)
for s in self.sessions:
self.assertIn(s.session_id, sessions_dict)
self.assertEqual(s, sessions_dict[s.session_id])
def test_prune_sessions_max(self):
self.create_sessions(self.session_manager.SESSIONS_MAX + 10)
self.assertEqual(len(self.session_manager.socket_sessions), self.session_manager.SESSIONS_MAX)
class SocketControllerUnitTest(SimpleTestCase):
def setUp(self):
self.socket_controller = SocketController(SocketSessionManager())
server = Mock()
self.socket_controller.set_server(server)
super(SocketControllerUnitTest, self).setUp()
def create_clients(self, count, token_key=None):
self.sessions = []
self.sockets =[]
self.count = count
self.sockets_dict = {}
for i in range(0, count):
if isinstance(token_key, list):
token_key_actual = token_key[i]
else:
token_key_actual = token_key or i
socket = MagicMock(session=dict())
socket_session = SocketSession(i, token_key_actual, socket)
self.sockets.append(socket)
self.sessions.append(socket_session)
self.sockets_dict[i] = socket
self.socket_controller.add_session(socket_session)
socket.session['socket_session'] = socket_session
socket.send_packet = Mock()
self.socket_controller.server.sockets = self.sockets_dict
def test_broadcast_packet(self):
self.create_clients(10)
packet = {
"hello": "world"
}
self.socket_controller.broadcast_packet(packet)
for s in self.sockets:
s.send_packet.assert_called_with(packet)
def test_send_packet(self):
self.create_clients(5, token_key=[0, 1, 2, 3, 4])
packet = {
"hello": "world"
}
self.socket_controller.send_packet(packet, 2)
self.assertEqual(0, len(self.sockets[0].send_packet.mock_calls))
self.assertEqual(0, len(self.sockets[1].send_packet.mock_calls))
self.sockets[2].send_packet.assert_called_once_with(packet)
self.assertEqual(0, len(self.sockets[3].send_packet.mock_calls))
self.assertEqual(0, len(self.sockets[4].send_packet.mock_calls))
def test_send_packet_multiple_sessions_one_token(self):
self.create_clients(5, token_key=[0, 1, 1, 1, 2])
packet = {
"hello": "world"
}
self.socket_controller.send_packet(packet, 1)
self.assertEqual(0, len(self.sockets[0].send_packet.mock_calls))
self.sockets[1].send_packet.assert_called_once_with(packet)
self.sockets[2].send_packet.assert_called_once_with(packet)
self.sockets[3].send_packet.assert_called_once_with(packet)
self.assertEqual(0, len(self.sockets[4].send_packet.mock_calls))