From ece77cd6bb70c1214dd1d5d7b32d287a30d729eb Mon Sep 17 00:00:00 2001 From: Andrzej Pragacz Date: Sun, 25 Aug 2019 17:49:05 +0200 Subject: [PATCH] Fix issue #65 * Add SessionCSRFAuthentication --- rest_registration/api/authentication.py | 9 ++++++ tests/api/test_login.py | 43 ++++++++++++++++--------- tests/helpers.py | 27 ++++++++++++++++ 3 files changed, 64 insertions(+), 15 deletions(-) create mode 100644 rest_registration/api/authentication.py create mode 100644 tests/helpers.py diff --git a/rest_registration/api/authentication.py b/rest_registration/api/authentication.py new file mode 100644 index 0000000..e36a5ea --- /dev/null +++ b/rest_registration/api/authentication.py @@ -0,0 +1,9 @@ +from rest_framework.authentication import SessionAuthentication + + +class SessionCSRFAuthentication(SessionAuthentication): + + def authenticate(self, request): + user = getattr(request._request, 'user', None) # noqa: E501 pylint: disable=protected-access + self.enforce_csrf(request) + return (user, None) diff --git a/tests/api/test_login.py b/tests/api/test_login.py index 532941f..443a77f 100644 --- a/tests/api/test_login.py +++ b/tests/api/test_login.py @@ -1,8 +1,9 @@ from django.test.utils import modify_settings, override_settings from rest_framework import status from rest_framework.authtoken.models import Token -from rest_framework.test import force_authenticate +from rest_framework.test import APIRequestFactory, force_authenticate +from ..helpers import override_rest_framework_settings_dict from .base import APIViewTestCase @@ -53,6 +54,22 @@ def test_invalid(self): response = self.view_func(request) self.assert_invalid_response(response, status.HTTP_400_BAD_REQUEST) + @override_rest_framework_settings_dict({ + 'DEFAULT_AUTHENTICATION_CLASSES': [ + 'rest_registration.api.authentication.SessionCSRFAuthentication', + ], + }) + def test_csrf(self): + factory = APIRequestFactory(enforce_csrf_checks=True) + request = factory.post("/login", { + 'login': self.user.username, + 'password': self.password, + }) + self.add_session_to_request(request) + response = self.view_func(request) + self.assert_invalid_response( + response, status.HTTP_403_FORBIDDEN) + class LogoutViewTestCase(BaseLoginTestCase): VIEW_NAME = 'logout' @@ -77,13 +94,11 @@ def test_revoke_token_success(self): 'remove': 'django.contrib.sessions.middleware.SessionMiddleware', } ) - @override_settings( - REST_FRAMEWORK={ - 'DEFAULT_AUTHENTICATION_CLASSES': ( - 'rest_framework.authentication.TokenAuthentication', - ), - }, - ) + @override_rest_framework_settings_dict({ + 'DEFAULT_AUTHENTICATION_CLASSES': ( + 'rest_framework.authentication.TokenAuthentication', + ), + }) def test_revoke_token_success_without_session(self): self._test_revoke_token_success(add_session=False) @@ -112,13 +127,11 @@ def test_revoke_nonexistent_token_failure(self): 'remove': 'django.contrib.sessions.middleware.SessionMiddleware', } ) - @override_settings( - REST_FRAMEWORK={ - 'DEFAULT_AUTHENTICATION_CLASSES': ( - 'rest_framework.authentication.TokenAuthentication', - ), - }, - ) + @override_rest_framework_settings_dict({ + 'DEFAULT_AUTHENTICATION_CLASSES': ( + 'rest_framework.authentication.TokenAuthentication', + ), + }) def test_revoke_nonexistent_token_failure_without_session(self): self._test_revoke_nonexistent_token_failure(add_session=False) diff --git a/tests/helpers.py b/tests/helpers.py new file mode 100644 index 0000000..bcabb72 --- /dev/null +++ b/tests/helpers.py @@ -0,0 +1,27 @@ +import contextlib + +from django.test.utils import override_settings +from rest_framework.settings import api_settings +from rest_framework.views import APIView + + +@contextlib.contextmanager +def override_rest_framework_settings_dict(settings_dict): + try: + with override_settings(REST_FRAMEWORK=settings_dict): + _update_api_view_class_attrs() + yield + finally: + _update_api_view_class_attrs() + + +def _update_api_view_class_attrs(): + attrs = [ + 'renderer_classes', 'parser_classes', 'authentication_classes', + 'throttle_classes', 'permission_classes', 'content_negotiation_class', + 'metadata_class', 'versioning_class', + ] + for attr in attrs: + attr_upper = attr.upper() + api_settings_key = 'DEFAULT_{attr_upper}'.format(attr_upper=attr_upper) + setattr(APIView, attr, getattr(api_settings, api_settings_key))