Skip to content

Commit

Permalink
DjangoConnectionField slice: use max_limit first, if set (graphql-pyt…
Browse files Browse the repository at this point in the history
  • Loading branch information
pcraciunoiu authored Jun 6, 2020
1 parent 40e9c66 commit c002034
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 30 deletions.
83 changes: 58 additions & 25 deletions graphene_django/debug/tests/test_query.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import graphene
import pytest
from graphene.relay import Node
from graphene_django import DjangoConnectionField, DjangoObjectType

Expand All @@ -24,7 +25,7 @@ class Meta:

class Query(graphene.ObjectType):
reporter = graphene.Field(ReporterType)
debug = graphene.Field(DjangoDebug, name="_debug")
debug = graphene.Field(DjangoDebug, name="__debug")

def resolve_reporter(self, info, **args):
return Reporter.objects.first()
Expand All @@ -34,7 +35,7 @@ def resolve_reporter(self, info, **args):
reporter {
lastName
}
_debug {
__debug {
sql {
rawSql
}
Expand All @@ -43,7 +44,9 @@ def resolve_reporter(self, info, **args):
"""
expected = {
"reporter": {"lastName": "ABA"},
"_debug": {"sql": [{"rawSql": str(Reporter.objects.order_by("pk")[:1].query)}]},
"__debug": {
"sql": [{"rawSql": str(Reporter.objects.order_by("pk")[:1].query)}]
},
}
schema = graphene.Schema(query=Query)
result = schema.execute(
Expand All @@ -53,7 +56,10 @@ def resolve_reporter(self, info, **args):
assert result.data == expected


def test_should_query_nested_field():
@pytest.mark.parametrize("max_limit,does_count", [(None, True), (100, False)])
def test_should_query_nested_field(graphene_settings, max_limit, does_count):
graphene_settings.RELAY_CONNECTION_MAX_LIMIT = max_limit

r1 = Reporter(last_name="ABA")
r1.save()
r2 = Reporter(last_name="Griffin")
Expand Down Expand Up @@ -111,11 +117,18 @@ def resolve_reporter(self, info, **args):
assert not result.errors
query = str(Reporter.objects.order_by("pk")[:1].query)
assert result.data["__debug"]["sql"][0]["rawSql"] == query
assert "COUNT" in result.data["__debug"]["sql"][1]["rawSql"]
assert "tests_reporter_pets" in result.data["__debug"]["sql"][2]["rawSql"]
assert "COUNT" in result.data["__debug"]["sql"][3]["rawSql"]
assert "tests_reporter_pets" in result.data["__debug"]["sql"][4]["rawSql"]
assert len(result.data["__debug"]["sql"]) == 5
if does_count:
assert "COUNT" in result.data["__debug"]["sql"][1]["rawSql"]
assert "tests_reporter_pets" in result.data["__debug"]["sql"][2]["rawSql"]
assert "COUNT" in result.data["__debug"]["sql"][3]["rawSql"]
assert "tests_reporter_pets" in result.data["__debug"]["sql"][4]["rawSql"]
assert len(result.data["__debug"]["sql"]) == 5
else:
assert len(result.data["__debug"]["sql"]) == 3
for i in range(len(result.data["__debug"]["sql"])):
assert "COUNT" not in result.data["__debug"]["sql"][i]["rawSql"]
assert "tests_reporter_pets" in result.data["__debug"]["sql"][1]["rawSql"]
assert "tests_reporter_pets" in result.data["__debug"]["sql"][2]["rawSql"]

assert result.data["reporter"] == expected["reporter"]

Expand All @@ -133,7 +146,7 @@ class Meta:

class Query(graphene.ObjectType):
all_reporters = graphene.List(ReporterType)
debug = graphene.Field(DjangoDebug, name="_debug")
debug = graphene.Field(DjangoDebug, name="__debug")

def resolve_all_reporters(self, info, **args):
return Reporter.objects.all()
Expand All @@ -143,7 +156,7 @@ def resolve_all_reporters(self, info, **args):
allReporters {
lastName
}
_debug {
__debug {
sql {
rawSql
}
Expand All @@ -152,7 +165,7 @@ def resolve_all_reporters(self, info, **args):
"""
expected = {
"allReporters": [{"lastName": "ABA"}, {"lastName": "Griffin"}],
"_debug": {"sql": [{"rawSql": str(Reporter.objects.all().query)}]},
"__debug": {"sql": [{"rawSql": str(Reporter.objects.all().query)}]},
}
schema = graphene.Schema(query=Query)
result = schema.execute(
Expand All @@ -162,7 +175,10 @@ def resolve_all_reporters(self, info, **args):
assert result.data == expected


def test_should_query_connection():
@pytest.mark.parametrize("max_limit,does_count", [(None, True), (100, False)])
def test_should_query_connection(graphene_settings, max_limit, does_count):
graphene_settings.RELAY_CONNECTION_MAX_LIMIT = max_limit

r1 = Reporter(last_name="ABA")
r1.save()
r2 = Reporter(last_name="Griffin")
Expand All @@ -175,7 +191,7 @@ class Meta:

class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
debug = graphene.Field(DjangoDebug, name="_debug")
debug = graphene.Field(DjangoDebug, name="__debug")

def resolve_all_reporters(self, info, **args):
return Reporter.objects.all()
Expand All @@ -189,7 +205,7 @@ def resolve_all_reporters(self, info, **args):
}
}
}
_debug {
__debug {
sql {
rawSql
}
Expand All @@ -203,12 +219,22 @@ def resolve_all_reporters(self, info, **args):
)
assert not result.errors
assert result.data["allReporters"] == expected["allReporters"]
assert "COUNT" in result.data["_debug"]["sql"][0]["rawSql"]
query = str(Reporter.objects.all()[:1].query)
assert result.data["_debug"]["sql"][1]["rawSql"] == query

if does_count:
assert len(result.data["__debug"]["sql"]) == 2
assert "COUNT" in result.data["__debug"]["sql"][0]["rawSql"]
query = str(Reporter.objects.all()[:1].query)
assert result.data["__debug"]["sql"][1]["rawSql"] == query
else:
assert len(result.data["__debug"]["sql"]) == 1
assert "COUNT" not in result.data["__debug"]["sql"][0]["rawSql"]
query = str(Reporter.objects.all()[:1].query)
assert result.data["__debug"]["sql"][0]["rawSql"] == query


@pytest.mark.parametrize("max_limit,does_count", [(None, True), (100, False)])
def test_should_query_connectionfilter(graphene_settings, max_limit, does_count):
graphene_settings.RELAY_CONNECTION_MAX_LIMIT = max_limit

def test_should_query_connectionfilter():
from ...filter import DjangoFilterConnectionField

r1 = Reporter(last_name="ABA")
Expand All @@ -224,7 +250,7 @@ class Meta:
class Query(graphene.ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterType, fields=["last_name"])
s = graphene.String(resolver=lambda *_: "S")
debug = graphene.Field(DjangoDebug, name="_debug")
debug = graphene.Field(DjangoDebug, name="__debug")

def resolve_all_reporters(self, info, **args):
return Reporter.objects.all()
Expand All @@ -238,7 +264,7 @@ def resolve_all_reporters(self, info, **args):
}
}
}
_debug {
__debug {
sql {
rawSql
}
Expand All @@ -252,6 +278,13 @@ def resolve_all_reporters(self, info, **args):
)
assert not result.errors
assert result.data["allReporters"] == expected["allReporters"]
assert "COUNT" in result.data["_debug"]["sql"][0]["rawSql"]
query = str(Reporter.objects.all()[:1].query)
assert result.data["_debug"]["sql"][1]["rawSql"] == query
if does_count:
assert len(result.data["__debug"]["sql"]) == 2
assert "COUNT" in result.data["__debug"]["sql"][0]["rawSql"]
query = str(Reporter.objects.all()[:1].query)
assert result.data["__debug"]["sql"][1]["rawSql"] == query
else:
assert len(result.data["__debug"]["sql"]) == 1
assert "COUNT" not in result.data["__debug"]["sql"][0]["rawSql"]
query = str(Reporter.objects.all()[:1].query)
assert result.data["__debug"]["sql"][0]["rawSql"] == query
13 changes: 9 additions & 4 deletions graphene_django/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,15 @@ def resolve_queryset(cls, connection, queryset, info, args):
return connection._meta.node.get_queryset(queryset, info)

@classmethod
def resolve_connection(cls, connection, args, iterable):
def resolve_connection(cls, connection, args, iterable, max_limit=None):
iterable = maybe_queryset(iterable)
# When slicing from the end, need to retrieve the iterable length.
if args.get("last"):
max_limit = None
if isinstance(iterable, QuerySet):
_len = iterable.count()
_len = max_limit or iterable.count()
else:
_len = len(iterable)
_len = max_limit or len(iterable)
connection = connection_from_list_slice(
iterable,
args,
Expand Down Expand Up @@ -189,7 +192,9 @@ def connection_resolver(
# thus the iterable gets refiltered by resolve_queryset
# but iterable might be promise
iterable = queryset_resolver(connection, iterable, info, args)
on_resolve = partial(cls.resolve_connection, connection, args)
on_resolve = partial(
cls.resolve_connection, connection, args, max_limit=max_limit
)

if Promise.is_thenable(iterable):
return Promise.resolve(iterable).then(on_resolve)
Expand Down
44 changes: 43 additions & 1 deletion graphene_django/tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -1084,6 +1084,48 @@ class Query(graphene.ObjectType):
assert result.data == expected


REPORTERS = [
dict(
first_name="First {}".format(i),
last_name="Last {}".format(i),
email="johndoe+{}@example.com".format(i),
a_choice=1,
)
for i in range(6)
]


def test_should_return_max_limit(graphene_settings):
graphene_settings.RELAY_CONNECTION_MAX_LIMIT = 4
reporters = [Reporter(**kwargs) for kwargs in REPORTERS]
Reporter.objects.bulk_create(reporters)

class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node,)

class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)

schema = graphene.Schema(query=Query)
query = """
query AllReporters {
allReporters {
edges {
node {
id
}
}
}
}
"""

result = schema.execute(query)
assert not result.errors
assert len(result.data["allReporters"]["edges"]) == 4


def test_should_preserve_prefetch_related(django_assert_num_queries):
class ReporterType(DjangoObjectType):
class Meta:
Expand Down Expand Up @@ -1130,7 +1172,7 @@ def resolve_films(root, info):
}
"""
schema = graphene.Schema(query=Query)
with django_assert_num_queries(3) as captured:
with django_assert_num_queries(2) as captured:
result = schema.execute(query)
assert not result.errors

Expand Down

0 comments on commit c002034

Please sign in to comment.