Skip to content

Commit

Permalink
fix: empty list is not an empty value for list filters even when a cu…
Browse files Browse the repository at this point in the history
…stom filtering method is provided (#1450)

Co-authored-by: Thomas Leonard <thomas@loftorbital.com>
  • Loading branch information
tcleonard and Thomas Leonard authored Aug 11, 2023
1 parent 720db1f commit 0473f1a
Show file tree
Hide file tree
Showing 9 changed files with 450 additions and 94 deletions.
24 changes: 23 additions & 1 deletion graphene_django/compat.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import sys
from pathlib import PurePath

# For backwards compatibility, we import JSONField to have it available for import via
# this compat module (https://github.com/graphql-python/graphene-django/issues/1428).
# Django's JSONField is available in Django 3.2+ (the minimum version we support)
Expand All @@ -19,4 +22,23 @@ def __init__(self, *args, **kwargs):
RangeField,
)
except ImportError:
IntegerRangeField, ArrayField, HStoreField, RangeField = (MissingType,) * 4
IntegerRangeField, HStoreField, RangeField = (MissingType,) * 3

# For unit tests we fake ArrayField using JSONFields
if any(
PurePath(sys.argv[0]).match(p)
for p in [
"**/pytest",
"**/py.test",
"**/pytest/__main__.py",
]
):

class ArrayField(JSONField):
def __init__(self, *args, **kwargs):
if len(args) > 0:
self.base_field = args[0]
super().__init__(**kwargs)

else:
ArrayField = MissingType
23 changes: 23 additions & 0 deletions graphene_django/filter/filters/array_filter.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,36 @@
from django_filters.constants import EMPTY_VALUES
from django_filters.filters import FilterMethod

from .typed_filter import TypedFilter


class ArrayFilterMethod(FilterMethod):
def __call__(self, qs, value):
if value is None:
return qs
return self.method(qs, self.f.field_name, value)


class ArrayFilter(TypedFilter):
"""
Filter made for PostgreSQL ArrayField.
"""

@TypedFilter.method.setter
def method(self, value):
"""
Override method setter so that in case a custom `method` is provided
(see documentation https://django-filter.readthedocs.io/en/stable/ref/filters.html#method),
it doesn't fall back to checking if the value is in `EMPTY_VALUES` (from the `__call__` method
of the `FilterMethod` class) and instead use our ArrayFilterMethod that consider empty lists as values.
Indeed when providing a `method` the `filter` method below is overridden and replaced by `FilterMethod(self)`
which means that the validation of the empty value is made by the `FilterMethod.__call__` method instead.
"""
TypedFilter.method.fset(self, value)
if value is not None:
self.filter = ArrayFilterMethod(self)

def filter(self, qs, value):
"""
Override the default filter class to check first whether the list is
Expand Down
24 changes: 24 additions & 0 deletions graphene_django/filter/filters/list_filter.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,36 @@
from django_filters.filters import FilterMethod

from .typed_filter import TypedFilter


class ListFilterMethod(FilterMethod):
def __call__(self, qs, value):
if value is None:
return qs
return self.method(qs, self.f.field_name, value)


class ListFilter(TypedFilter):
"""
Filter that takes a list of value as input.
It is for example used for `__in` filters.
"""

@TypedFilter.method.setter
def method(self, value):
"""
Override method setter so that in case a custom `method` is provided
(see documentation https://django-filter.readthedocs.io/en/stable/ref/filters.html#method),
it doesn't fall back to checking if the value is in `EMPTY_VALUES` (from the `__call__` method
of the `FilterMethod` class) and instead use our ListFilterMethod that consider empty lists as values.
Indeed when providing a `method` the `filter` method below is overridden and replaced by `FilterMethod(self)`
which means that the validation of the empty value is made by the `FilterMethod.__call__` method instead.
"""
TypedFilter.method.fset(self, value)
if value is not None:
self.filter = ListFilterMethod(self)

def filter(self, qs, value):
"""
Override the default filter class to check first whether the list is
Expand Down
152 changes: 97 additions & 55 deletions graphene_django/filter/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from unittest.mock import MagicMock
from functools import reduce

import pytest
from django.db import models
Expand All @@ -25,15 +25,15 @@
)


STORE = {"events": []}


class Event(models.Model):
name = models.CharField(max_length=50)
tags = ArrayField(models.CharField(max_length=50))
tag_ids = ArrayField(models.IntegerField())
random_field = ArrayField(models.BooleanField())

def __repr__(self):
return f"Event [{self.name}]"


@pytest.fixture
def EventFilterSet():
Expand All @@ -48,6 +48,14 @@ class Meta:
tags__contains = ArrayFilter(field_name="tags", lookup_expr="contains")
tags__overlap = ArrayFilter(field_name="tags", lookup_expr="overlap")
tags = ArrayFilter(field_name="tags", lookup_expr="exact")
tags__len = ArrayFilter(
field_name="tags", lookup_expr="len", input_type=graphene.Int
)
tags__len__in = ArrayFilter(
field_name="tags",
method="tags__len__in_filter",
input_type=graphene.List(graphene.Int),
)

# Those are actually not usable and only to check type declarations
tags_ids__contains = ArrayFilter(field_name="tag_ids", lookup_expr="contains")
Expand All @@ -61,6 +69,14 @@ class Meta:
)
random_field = ArrayFilter(field_name="random_field", lookup_expr="exact")

def tags__len__in_filter(self, queryset, _name, value):
if not value:
return queryset.none()
return reduce(
lambda q1, q2: q1.union(q2),
[queryset.filter(tags__len=v) for v in value],
).distinct()

return EventFilterSet


Expand All @@ -83,68 +99,94 @@ def Query(EventType):
we are running unit tests in sqlite which does not have ArrayFields.
"""

events = [
Event(name="Live Show", tags=["concert", "music", "rock"]),
Event(name="Musical", tags=["movie", "music"]),
Event(name="Ballet", tags=["concert", "dance"]),
Event(name="Speech", tags=[]),
]

class Query(graphene.ObjectType):
events = DjangoFilterConnectionField(EventType)

def resolve_events(self, info, **kwargs):
events = [
Event(name="Live Show", tags=["concert", "music", "rock"]),
Event(name="Musical", tags=["movie", "music"]),
Event(name="Ballet", tags=["concert", "dance"]),
Event(name="Speech", tags=[]),
]

STORE["events"] = events

m_queryset = MagicMock(spec=QuerySet)
m_queryset.model = Event

def filter_events(**kwargs):
if "tags__contains" in kwargs:
STORE["events"] = list(
filter(
lambda e: set(kwargs["tags__contains"]).issubset(
set(e.tags)
),
STORE["events"],
class FakeQuerySet(QuerySet):
def __init__(self, model=None):
self.model = Event
self.__store = list(events)

def all(self):
return self

def filter(self, **kwargs):
queryset = FakeQuerySet()
queryset.__store = list(self.__store)
if "tags__contains" in kwargs:
queryset.__store = list(
filter(
lambda e: set(kwargs["tags__contains"]).issubset(
set(e.tags)
),
queryset.__store,
)
)
if "tags__overlap" in kwargs:
queryset.__store = list(
filter(
lambda e: not set(kwargs["tags__overlap"]).isdisjoint(
set(e.tags)
),
queryset.__store,
)
)
)
if "tags__overlap" in kwargs:
STORE["events"] = list(
filter(
lambda e: not set(kwargs["tags__overlap"]).isdisjoint(
set(e.tags)
),
STORE["events"],
if "tags__exact" in kwargs:
queryset.__store = list(
filter(
lambda e: set(kwargs["tags__exact"]) == set(e.tags),
queryset.__store,
)
)
)
if "tags__exact" in kwargs:
STORE["events"] = list(
filter(
lambda e: set(kwargs["tags__exact"]) == set(e.tags),
STORE["events"],
if "tags__len" in kwargs:
queryset.__store = list(
filter(
lambda e: len(e.tags) == kwargs["tags__len"],
queryset.__store,
)
)
)
return queryset

def union(self, *args):
queryset = FakeQuerySet()
queryset.__store = self.__store
for arg in args:
queryset.__store += arg.__store
return queryset

def mock_queryset_filter(*args, **kwargs):
filter_events(**kwargs)
return m_queryset
def none(self):
queryset = FakeQuerySet()
queryset.__store = []
return queryset

def mock_queryset_none(*args, **kwargs):
STORE["events"] = []
return m_queryset
def count(self):
return len(self.__store)

def mock_queryset_count(*args, **kwargs):
return len(STORE["events"])
def distinct(self):
queryset = FakeQuerySet()
queryset.__store = []
for event in self.__store:
if event not in queryset.__store:
queryset.__store.append(event)
queryset.__store = sorted(queryset.__store, key=lambda e: e.name)
return queryset

m_queryset.all.return_value = m_queryset
m_queryset.filter.side_effect = mock_queryset_filter
m_queryset.none.side_effect = mock_queryset_none
m_queryset.count.side_effect = mock_queryset_count
m_queryset.__getitem__.side_effect = lambda index: STORE[
"events"
].__getitem__(index)
def __getitem__(self, index):
return self.__store[index]

return m_queryset
return FakeQuerySet()

return Query


@pytest.fixture
def schema(Query):
return graphene.Schema(query=Query)
14 changes: 3 additions & 11 deletions graphene_django/filter/tests/test_array_field_contains_filter.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,14 @@
import pytest

from graphene import Schema

from ...compat import ArrayField, MissingType


@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_contains_multiple(Query):
def test_array_field_contains_multiple(schema):
"""
Test contains filter on a array field of string.
"""

schema = Schema(query=Query)

query = """
query {
events (tags_Contains: ["concert", "music"]) {
Expand All @@ -32,13 +28,11 @@ def test_array_field_contains_multiple(Query):


@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_contains_one(Query):
def test_array_field_contains_one(schema):
"""
Test contains filter on a array field of string.
"""

schema = Schema(query=Query)

query = """
query {
events (tags_Contains: ["music"]) {
Expand All @@ -59,13 +53,11 @@ def test_array_field_contains_one(Query):


@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_contains_empty_list(Query):
def test_array_field_contains_empty_list(schema):
"""
Test contains filter on a array field of string.
"""

schema = Schema(query=Query)

query = """
query {
events (tags_Contains: []) {
Expand Down
Loading

0 comments on commit 0473f1a

Please sign in to comment.