Skip to content

Fix request-based filtering #161

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 4, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 20 additions & 24 deletions rest_framework_filters/backends.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@

from django.template import Template, TemplateDoesNotExist, loader
from rest_framework import compat
from contextlib import contextmanager
from django_filters.rest_framework import backends

from .filterset import FilterSet
Expand All @@ -9,30 +8,27 @@
class DjangoFilterBackend(backends.DjangoFilterBackend):
default_filter_set = FilterSet

def filter_queryset(self, request, queryset, view):
filter_class = self.get_filter_class(view, queryset)

if filter_class:
if hasattr(filter_class, 'get_subset'):
filter_class = filter_class.get_subset(request.query_params)
return filter_class(request.query_params, queryset=queryset).qs
@contextmanager
def patched_filter_class(self, request):
"""
Patch `get_filter_class()` to get the subset based on the request params
"""
original = self.get_filter_class

return queryset
def get_subset_class(view, queryset=None):
filter_class = original(view, queryset)

def to_html(self, request, queryset, view):
filter_class = self.get_filter_class(view, queryset)
if not filter_class:
return None
filter_instance = filter_class(request.query_params, queryset=queryset)
if filter_class and hasattr(filter_class, 'get_subset'):
filter_class = filter_class.get_subset(request.query_params)

# forces `form` evaluation before `qs` is called. This prevents an empty form from being cached.
filter_instance.form
return filter_class

try:
template = loader.get_template(self.template)
except TemplateDoesNotExist:
template = Template(backends.template_default)
self.get_filter_class = get_subset_class
yield
self.get_filter_class = original

return compat.template_render(template, context={
'filter': filter_instance
})
def filter_queryset(self, request, queryset, view):
# patching the behavior of `get_filter_class()` in this method allows
# us to avoid maintenance issues with code duplication.
with self.patched_filter_class(request):
return super(DjangoFilterBackend, self).filter_queryset(request, queryset, view)
2 changes: 1 addition & 1 deletion rest_framework_filters/filterset.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def expand_filters(self):
if isinstance(f, filters.RelatedFilter) and filter_name in related_data:
subset_data = related_data[filter_name]
subset_class = f.filterset.get_subset(subset_data)
filterset = subset_class(data=subset_data)
filterset = subset_class(data=subset_data, request=self.request)

# modify filter names to account for relationship
for related_name, related_f in six.iteritems(filterset.expand_filters()):
Expand Down
23 changes: 23 additions & 0 deletions tests/test_backends.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@

from rest_framework.test import APITestCase, APIRequestFactory
from rest_framework_filters import FilterSet

from .testapp import models, views

Expand Down Expand Up @@ -58,3 +59,25 @@ class SimpleViewSet(views.FilterFieldsUserViewSet):
<button type="submit" class="btn btn-primary">Submit</button>
</form>
""")

def test_request_obj_is_passed(self):
"""
Ensure that the request object is passed from the backend to the filterset.
See: https://github.com/philipn/django-rest-framework-filters/issues/149
"""
class RequestCheck(FilterSet):
def __init__(self, *args, **kwargs):
super(RequestCheck, self).__init__(*args, **kwargs)
assert self.request is not None

class Meta:
model = models.User
fields = ['username']

class ViewSet(views.FilterFieldsUserViewSet):
filter_class = RequestCheck

view = ViewSet(action_map={})
backend = view.filter_backends[0]
request = view.initialize_request(factory.get('/'))
backend().filter_queryset(request, view.get_queryset(), view)
22 changes: 22 additions & 0 deletions tests/test_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,28 @@ class Meta:
msg = str(excinfo.exception)
self.assertEqual("Expected `.get_queryset()` to return a `QuerySet`, but got `None`.", msg)

def test_relatedfilter_request_is_passed(self):
class RequestCheck(FilterSet):
def __init__(self, *args, **kwargs):
super(RequestCheck, self).__init__(*args, **kwargs)
assert self.request is not None

class Meta:
model = User
fields = ['username']

class NoteFilter(FilterSet):
author = filters.RelatedFilter(RequestCheck, name='author')

class Meta:
model = Note
fields = []

GET = {'author__username': 'user2'}

# should pass
NoteFilter(GET, queryset=Note.objects.all(), request=object()).qs


class MiscTests(TestCase):
def test_multiwidget_incompatibility(self):
Expand Down