diff --git a/graphene_django/filter/tests/test_in_filter.py b/graphene_django/filter/tests/test_in_filter.py index 3d4034eaa..7bbee65a7 100644 --- a/graphene_django/filter/tests/test_in_filter.py +++ b/graphene_django/filter/tests/test_in_filter.py @@ -1,9 +1,11 @@ import pytest +from django_filters import FilterSet +from django_filters import rest_framework as filters from graphene import ObjectType, Schema from graphene.relay import Node from graphene_django import DjangoObjectType -from graphene_django.tests.models import Pet +from graphene_django.tests.models import Pet, Person from graphene_django.utils import DJANGO_FILTER_INSTALLED pytestmark = [] @@ -28,8 +30,27 @@ class Meta: } +class PersonFilterSet(FilterSet): + class Meta: + model = Person + fields = {} + + names = filters.BaseInFilter(method="filter_names") + + def filter_names(self, qs, name, value): + return qs.filter(name__in=value) + + +class PersonNode(DjangoObjectType): + class Meta: + model = Person + interfaces = (Node,) + filterset_class = PersonFilterSet + + class Query(ObjectType): pets = DjangoFilterConnectionField(PetNode) + people = DjangoFilterConnectionField(PersonNode) def test_string_in_filter(): @@ -61,6 +82,33 @@ def test_string_in_filter(): ] +def test_string_in_filter_with_filterset_class(): + """Test in filter on a string field with a custom filterset class.""" + Person.objects.create(name="John") + Person.objects.create(name="Michael") + Person.objects.create(name="Angela") + + schema = Schema(query=Query) + + query = """ + query { + people (names: ["John", "Michael"]) { + edges { + node { + name + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["people"]["edges"] == [ + {"node": {"name": "John"}}, + {"node": {"name": "Michael"}}, + ] + + def test_int_in_filter(): """ Test in filter on an integer field. diff --git a/graphene_django/filter/utils.py b/graphene_django/filter/utils.py index 0245f166f..5d3f4f4c8 100644 --- a/graphene_django/filter/utils.py +++ b/graphene_django/filter/utils.py @@ -17,6 +17,7 @@ def get_filtering_args_from_filterset(filterset_class, type): model = filterset_class._meta.model for name, filter_field in filterset_class.base_filters.items(): form_field = None + filter_type = filter_field.lookup_expr if name in filterset_class.declared_filters: # Get the filter field from the explicitly declared filter @@ -25,7 +26,6 @@ def get_filtering_args_from_filterset(filterset_class, type): else: # Get the filter field with no explicit type declaration model_field = get_model_field(model, filter_field.field_name) - filter_type = filter_field.lookup_expr if filter_type != "isnull" and hasattr(model_field, "formfield"): form_field = model_field.formfield( required=filter_field.extra.get("required", False) @@ -38,14 +38,14 @@ def get_filtering_args_from_filterset(filterset_class, type): field = convert_form_field(form_field) - if filter_type in ["in", "range"]: - # Replace CSV filters (`in`, `range`) argument type to be a list of the same type as the field. - # See comments in `replace_csv_filters` method for more details. - field = List(field.get_type()) + if filter_type in ["in", "range"]: + # Replace CSV filters (`in`, `range`) argument type to be a list of + # the same type as the field. See comments in + # `replace_csv_filters` method for more details. + field = List(field.get_type()) field_type = field.Argument() field_type.description = str(filter_field.label) if filter_field.label else None - args[name] = field_type return args @@ -78,10 +78,7 @@ def replace_csv_filters(filterset_class): """ for name, filter_field in list(filterset_class.base_filters.items()): filter_type = filter_field.lookup_expr - if ( - filter_type in ["in", "range"] - and name not in filterset_class.declared_filters - ): + if filter_type in ["in", "range"]: assert isinstance(filter_field, BaseCSVFilter) filterset_class.base_filters[name] = Filter( field_name=filter_field.field_name, diff --git a/graphene_django/tests/models.py b/graphene_django/tests/models.py index 708fe579a..180acc527 100644 --- a/graphene_django/tests/models.py +++ b/graphene_django/tests/models.py @@ -6,6 +6,10 @@ CHOICES = ((1, "this"), (2, _("that"))) +class Person(models.Model): + name = models.CharField(max_length=30) + + class Pet(models.Model): name = models.CharField(max_length=30) age = models.PositiveIntegerField() diff --git a/graphene_django/tests/test_utils.py b/graphene_django/tests/test_utils.py index 766032e42..d895f46bf 100644 --- a/graphene_django/tests/test_utils.py +++ b/graphene_django/tests/test_utils.py @@ -51,6 +51,7 @@ def runTest(self): pass tc = TestClass() + tc._pre_setup() tc.setUpClass() tc.query("query { }", operation_name="QueryName") body = json.loads(post_mock.call_args.args[1]) diff --git a/graphene_django/utils/testing.py b/graphene_django/utils/testing.py index 22efd8b1a..c783f71bf 100644 --- a/graphene_django/utils/testing.py +++ b/graphene_django/utils/testing.py @@ -1,6 +1,7 @@ import json +import warnings -from django.test import TestCase, Client +from django.test import Client, TestCase DEFAULT_GRAPHQL_URL = "/graphql/" @@ -68,12 +69,6 @@ class GraphQLTestCase(TestCase): # URL to graphql endpoint GRAPHQL_URL = DEFAULT_GRAPHQL_URL - @classmethod - def setUpClass(cls): - super(GraphQLTestCase, cls).setUpClass() - - cls._client = Client() - def query( self, query, operation_name=None, input_data=None, variables=None, headers=None ): @@ -101,10 +96,19 @@ def query( input_data=input_data, variables=variables, headers=headers, - client=self._client, + client=self.client, graphql_url=self.GRAPHQL_URL, ) + @property + def _client(self): + warnings.warn( + "Using `_client` is deprecated in favour of `client`.", + PendingDeprecationWarning, + stacklevel=2, + ) + return self.client + def assertResponseNoErrors(self, resp, msg=None): """ Assert that the call went through correctly. 200 means the syntax is ok, if there are no `errors`, diff --git a/graphene_django/utils/tests/test_testing.py b/graphene_django/utils/tests/test_testing.py new file mode 100644 index 000000000..df7832130 --- /dev/null +++ b/graphene_django/utils/tests/test_testing.py @@ -0,0 +1,24 @@ +import pytest + +from .. import GraphQLTestCase +from ...tests.test_types import with_local_registry + + +@with_local_registry +def test_graphql_test_case_deprecated_client(): + """ + Test that `GraphQLTestCase._client`'s should raise pending deprecation warning. + """ + + class TestClass(GraphQLTestCase): + GRAPHQL_SCHEMA = True + + def runTest(self): + pass + + tc = TestClass() + tc._pre_setup() + tc.setUpClass() + + with pytest.warns(PendingDeprecationWarning): + tc._client