diff --git a/awx/main/consumers.py b/awx/main/consumers.py index ff55507939..39020099d1 100644 --- a/awx/main/consumers.py +++ b/awx/main/consumers.py @@ -1,16 +1,14 @@ import json import logging -import urllib from channels import Group, channel_layers -from channels.sessions import channel_session -from channels.handler import AsgiRequest +from channels.sessions import enforce_ordering, channel_session, channel_and_http_session from django.conf import settings from django.core.serializers.json import DjangoJSONEncoder from django.contrib.auth.models import User -from awx.main.models.organization import AuthToken +from django.contrib.sessions.models import Session logger = logging.getLogger('awx.main.consumers') @@ -22,24 +20,21 @@ def discard_groups(message): Group(group).discard(message.reply_channel) -@channel_session +@channel_and_http_session def ws_connect(message): - connect_text = {'accept':False, 'user':None} + if message.http_session.session_key is None: + raise ValueError('No valid session key to get auth from') - message.content['method'] = 'FAKE' - request = AsgiRequest(message) - token = request.COOKIES.get('token', None) - if token is not None: - token = urllib.unquote(token).strip('"') - try: - auth_token = AuthToken.objects.get(key=token) - if auth_token.in_valid_tokens: - message.channel_session['user_id'] = auth_token.user_id - connect_text['accept'] = True - connect_text['user'] = auth_token.user_id - except AuthToken.DoesNotExist: - logger.error("auth_token provided was invalid.") - message.reply_channel.send({"text": json.dumps(connect_text)}) + session = Session.objects.get(session_key=message.http_session.session_key) + session_data = session.get_decoded() + + try: + user = User.objects.get(pk=session_data['_auth_user_id']) + except User.DoesNotExist: + raise ValueError('No valid user for the session key') + + message.channel_session['user_id'] = user.pk + message.reply_channel.send({"text": json.dumps({'accept': True, 'user': user.pk})}) @channel_session @@ -47,6 +42,7 @@ def ws_disconnect(message): discard_groups(message) +@enforce_ordering @channel_session def ws_receive(message): from awx.main.access import consumer_access diff --git a/awx/settings/defaults.py b/awx/settings/defaults.py index d076e234ea..ec5f18ba43 100644 --- a/awx/settings/defaults.py +++ b/awx/settings/defaults.py @@ -189,6 +189,9 @@ JOB_EVENT_MAX_QUEUE_SIZE = 10000 # Disallow sending session cookies over insecure connections SESSION_COOKIE_SECURE = True +# Do not allow non-browser clients to read the CSRF cookie. +CSRF_COOKIE_HTTPONLY = True + # Disallow sending csrf cookies over insecure connections CSRF_COOKIE_SECURE = True diff --git a/awx/sso/views.py b/awx/sso/views.py index 80092a8040..84826a0bb0 100644 --- a/awx/sso/views.py +++ b/awx/sso/views.py @@ -60,7 +60,7 @@ class CompleteView(BaseRedirectView): logger.info(smart_text(u"User {} logged in".format(self.request.user.username))) request.session['auth_token_key'] = token.key token_key = urllib.quote('"%s"' % token.key) - response.set_cookie('token', token_key) + response.set_cookie('token', value=token_key, httponly=True) token_expires = token.expires.astimezone(utc).strftime('%Y-%m-%dT%H:%M:%S') token_expires = '%s.%03dZ' % (token_expires, token.expires.microsecond / 1000) token_expires = urllib.quote('"%s"' % token_expires)