Skip to content

Commit

Permalink
Fix hasNextPage - revert to count. Fix after (graphql-python#986)
Browse files Browse the repository at this point in the history
Co-authored-by: Jonathan Kim <jkimbo@gmail.com>
  • Loading branch information
pcraciunoiu and jkimbo authored Jun 25, 2020
1 parent 3c6733e commit 3c229b6
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 52 deletions.
61 changes: 21 additions & 40 deletions graphene_django/debug/tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def resolve_reporter(self, info, **args):
assert result.data == expected


@pytest.mark.parametrize("max_limit,does_count", [(None, True), (100, False)])
def test_should_query_nested_field(graphene_settings, max_limit, does_count):
@pytest.mark.parametrize("max_limit", [None, 100])
def test_should_query_nested_field(graphene_settings, max_limit):
graphene_settings.RELAY_CONNECTION_MAX_LIMIT = max_limit

r1 = Reporter(last_name="ABA")
Expand Down Expand Up @@ -117,18 +117,11 @@ def resolve_reporter(self, info, **args):
assert not result.errors
query = str(Reporter.objects.order_by("pk")[:1].query)
assert result.data["__debug"]["sql"][0]["rawSql"] == query
if does_count:
assert "COUNT" in result.data["__debug"]["sql"][1]["rawSql"]
assert "tests_reporter_pets" in result.data["__debug"]["sql"][2]["rawSql"]
assert "COUNT" in result.data["__debug"]["sql"][3]["rawSql"]
assert "tests_reporter_pets" in result.data["__debug"]["sql"][4]["rawSql"]
assert len(result.data["__debug"]["sql"]) == 5
else:
assert len(result.data["__debug"]["sql"]) == 3
for i in range(len(result.data["__debug"]["sql"])):
assert "COUNT" not in result.data["__debug"]["sql"][i]["rawSql"]
assert "tests_reporter_pets" in result.data["__debug"]["sql"][1]["rawSql"]
assert "tests_reporter_pets" in result.data["__debug"]["sql"][2]["rawSql"]
assert "COUNT" in result.data["__debug"]["sql"][1]["rawSql"]
assert "tests_reporter_pets" in result.data["__debug"]["sql"][2]["rawSql"]
assert "COUNT" in result.data["__debug"]["sql"][3]["rawSql"]
assert "tests_reporter_pets" in result.data["__debug"]["sql"][4]["rawSql"]
assert len(result.data["__debug"]["sql"]) == 5

assert result.data["reporter"] == expected["reporter"]

Expand Down Expand Up @@ -175,8 +168,8 @@ def resolve_all_reporters(self, info, **args):
assert result.data == expected


@pytest.mark.parametrize("max_limit,does_count", [(None, True), (100, False)])
def test_should_query_connection(graphene_settings, max_limit, does_count):
@pytest.mark.parametrize("max_limit", [None, 100])
def test_should_query_connection(graphene_settings, max_limit):
graphene_settings.RELAY_CONNECTION_MAX_LIMIT = max_limit

r1 = Reporter(last_name="ABA")
Expand Down Expand Up @@ -219,20 +212,14 @@ def resolve_all_reporters(self, info, **args):
)
assert not result.errors
assert result.data["allReporters"] == expected["allReporters"]
if does_count:
assert len(result.data["__debug"]["sql"]) == 2
assert "COUNT" in result.data["__debug"]["sql"][0]["rawSql"]
query = str(Reporter.objects.all()[:1].query)
assert result.data["__debug"]["sql"][1]["rawSql"] == query
else:
assert len(result.data["__debug"]["sql"]) == 1
assert "COUNT" not in result.data["__debug"]["sql"][0]["rawSql"]
query = str(Reporter.objects.all()[:1].query)
assert result.data["__debug"]["sql"][0]["rawSql"] == query


@pytest.mark.parametrize("max_limit,does_count", [(None, True), (100, False)])
def test_should_query_connectionfilter(graphene_settings, max_limit, does_count):
assert len(result.data["__debug"]["sql"]) == 2
assert "COUNT" in result.data["__debug"]["sql"][0]["rawSql"]
query = str(Reporter.objects.all()[:1].query)
assert result.data["__debug"]["sql"][1]["rawSql"] == query


@pytest.mark.parametrize("max_limit", [None, 100])
def test_should_query_connectionfilter(graphene_settings, max_limit):
graphene_settings.RELAY_CONNECTION_MAX_LIMIT = max_limit

from ...filter import DjangoFilterConnectionField
Expand Down Expand Up @@ -278,13 +265,7 @@ def resolve_all_reporters(self, info, **args):
)
assert not result.errors
assert result.data["allReporters"] == expected["allReporters"]
if does_count:
assert len(result.data["__debug"]["sql"]) == 2
assert "COUNT" in result.data["__debug"]["sql"][0]["rawSql"]
query = str(Reporter.objects.all()[:1].query)
assert result.data["__debug"]["sql"][1]["rawSql"] == query
else:
assert len(result.data["__debug"]["sql"]) == 1
assert "COUNT" not in result.data["__debug"]["sql"][0]["rawSql"]
query = str(Reporter.objects.all()[:1].query)
assert result.data["__debug"]["sql"][0]["rawSql"] == query
assert len(result.data["__debug"]["sql"]) == 2
assert "COUNT" in result.data["__debug"]["sql"][0]["rawSql"]
query = str(Reporter.objects.all()[:1].query)
assert result.data["__debug"]["sql"][1]["rawSql"] == query
32 changes: 21 additions & 11 deletions graphene_django/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

import six
from django.db.models.query import QuerySet
from graphql_relay.connection.arrayconnection import connection_from_list_slice
from graphql_relay.connection.arrayconnection import (
connection_from_list_slice,
get_offset_with_default,
)
from promise import Promise

from graphene import NonNull
Expand Down Expand Up @@ -129,25 +132,32 @@ def resolve_queryset(cls, connection, queryset, info, args):
@classmethod
def resolve_connection(cls, connection, args, iterable, max_limit=None):
iterable = maybe_queryset(iterable)
# When slicing from the end, need to retrieve the iterable length.
if args.get("last"):
max_limit = None

if isinstance(iterable, QuerySet):
_len = max_limit or iterable.count()
list_length = iterable.count()
list_slice_length = (
min(max_limit, list_length) if max_limit is not None else list_length
)
else:
_len = max_limit or len(iterable)
list_length = len(iterable)
list_slice_length = (
min(max_limit, list_length) if max_limit is not None else list_length
)

after = get_offset_with_default(args.get("after"), -1) + 1

connection = connection_from_list_slice(
iterable,
iterable[after:],
args,
slice_start=0,
list_length=_len,
list_slice_length=_len,
slice_start=after,
list_length=list_length,
list_slice_length=list_slice_length,
connection_type=connection,
edge_type=connection.Edge,
pageinfo_type=PageInfo,
)
connection.iterable = iterable
connection.length = _len
connection.length = list_length
return connection

@classmethod
Expand Down
55 changes: 54 additions & 1 deletion graphene_django/tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -1126,6 +1126,59 @@ class Query(graphene.ObjectType):
assert len(result.data["allReporters"]["edges"]) == 4


def test_should_have_next_page(graphene_settings):
graphene_settings.RELAY_CONNECTION_MAX_LIMIT = 6
reporters = [Reporter(**kwargs) for kwargs in REPORTERS]
Reporter.objects.bulk_create(reporters)
db_reporters = Reporter.objects.all()

class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node,)

class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)

schema = graphene.Schema(query=Query)
# Need first: 4 here to trigger the `has_next_page` logic in graphql-relay
# See `arrayconnection.py::connection_from_list_slice`:
# has_next_page=isinstance(first, int) and end_offset < upper_bound
query = """
query AllReporters($first: Int, $after: String) {
allReporters(first: $first, after: $after) {
pageInfo {
hasNextPage
endCursor
}
edges {
node {
id
}
}
}
}
"""

result = schema.execute(query, variable_values=dict(first=4))
assert not result.errors
assert len(result.data["allReporters"]["edges"]) == 4
assert result.data["allReporters"]["pageInfo"]["hasNextPage"]

last_result = result.data["allReporters"]["pageInfo"]["endCursor"]
result2 = schema.execute(query, variable_values=dict(first=4, after=last_result))
assert not result2.errors
assert len(result2.data["allReporters"]["edges"]) == 2
assert not result2.data["allReporters"]["pageInfo"]["hasNextPage"]
gql_reporters = (
result.data["allReporters"]["edges"] + result2.data["allReporters"]["edges"]
)

assert {to_global_id("ReporterType", reporter.id) for reporter in db_reporters} == {
gql_reporter["node"]["id"] for gql_reporter in gql_reporters
}


def test_should_preserve_prefetch_related(django_assert_num_queries):
class ReporterType(DjangoObjectType):
class Meta:
Expand Down Expand Up @@ -1172,7 +1225,7 @@ def resolve_films(root, info):
}
"""
schema = graphene.Schema(query=Query)
with django_assert_num_queries(2) as captured:
with django_assert_num_queries(3) as captured:
result = schema.execute(query)
assert not result.errors

Expand Down

0 comments on commit 3c229b6

Please sign in to comment.