Skip to content

Commit

Permalink
Fix issue #65
Browse files Browse the repository at this point in the history
* Add SessionCSRFAuthentication
  • Loading branch information
apragacz committed Oct 8, 2019
1 parent 9e0fbec commit ece77cd
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 15 deletions.
9 changes: 9 additions & 0 deletions rest_registration/api/authentication.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from rest_framework.authentication import SessionAuthentication


class SessionCSRFAuthentication(SessionAuthentication):

def authenticate(self, request):
user = getattr(request._request, 'user', None) # noqa: E501 pylint: disable=protected-access
self.enforce_csrf(request)
return (user, None)
43 changes: 28 additions & 15 deletions tests/api/test_login.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from django.test.utils import modify_settings, override_settings
from rest_framework import status
from rest_framework.authtoken.models import Token
from rest_framework.test import force_authenticate
from rest_framework.test import APIRequestFactory, force_authenticate

from ..helpers import override_rest_framework_settings_dict
from .base import APIViewTestCase


Expand Down Expand Up @@ -53,6 +54,22 @@ def test_invalid(self):
response = self.view_func(request)
self.assert_invalid_response(response, status.HTTP_400_BAD_REQUEST)

@override_rest_framework_settings_dict({
'DEFAULT_AUTHENTICATION_CLASSES': [
'rest_registration.api.authentication.SessionCSRFAuthentication',
],
})
def test_csrf(self):
factory = APIRequestFactory(enforce_csrf_checks=True)
request = factory.post("/login", {
'login': self.user.username,
'password': self.password,
})
self.add_session_to_request(request)
response = self.view_func(request)
self.assert_invalid_response(
response, status.HTTP_403_FORBIDDEN)


class LogoutViewTestCase(BaseLoginTestCase):
VIEW_NAME = 'logout'
Expand All @@ -77,13 +94,11 @@ def test_revoke_token_success(self):
'remove': 'django.contrib.sessions.middleware.SessionMiddleware',
}
)
@override_settings(
REST_FRAMEWORK={
'DEFAULT_AUTHENTICATION_CLASSES': (
'rest_framework.authentication.TokenAuthentication',
),
},
)
@override_rest_framework_settings_dict({
'DEFAULT_AUTHENTICATION_CLASSES': (
'rest_framework.authentication.TokenAuthentication',
),
})
def test_revoke_token_success_without_session(self):
self._test_revoke_token_success(add_session=False)

Expand Down Expand Up @@ -112,13 +127,11 @@ def test_revoke_nonexistent_token_failure(self):
'remove': 'django.contrib.sessions.middleware.SessionMiddleware',
}
)
@override_settings(
REST_FRAMEWORK={
'DEFAULT_AUTHENTICATION_CLASSES': (
'rest_framework.authentication.TokenAuthentication',
),
},
)
@override_rest_framework_settings_dict({
'DEFAULT_AUTHENTICATION_CLASSES': (
'rest_framework.authentication.TokenAuthentication',
),
})
def test_revoke_nonexistent_token_failure_without_session(self):
self._test_revoke_nonexistent_token_failure(add_session=False)

Expand Down
27 changes: 27 additions & 0 deletions tests/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import contextlib

from django.test.utils import override_settings
from rest_framework.settings import api_settings
from rest_framework.views import APIView


@contextlib.contextmanager
def override_rest_framework_settings_dict(settings_dict):
try:
with override_settings(REST_FRAMEWORK=settings_dict):
_update_api_view_class_attrs()
yield
finally:
_update_api_view_class_attrs()


def _update_api_view_class_attrs():
attrs = [
'renderer_classes', 'parser_classes', 'authentication_classes',
'throttle_classes', 'permission_classes', 'content_negotiation_class',
'metadata_class', 'versioning_class',
]
for attr in attrs:
attr_upper = attr.upper()
api_settings_key = 'DEFAULT_{attr_upper}'.format(attr_upper=attr_upper)
setattr(APIView, attr, getattr(api_settings, api_settings_key))

0 comments on commit ece77cd

Please sign in to comment.