Skip to content

Commit 8487bc3

Browse files
committed
Fixing fragment spread oversight
1 parent 10fbf73 commit 8487bc3

File tree

3 files changed

+73
-6
lines changed

3 files changed

+73
-6
lines changed

graphene_mongo/fields.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -512,9 +512,11 @@ def chained_resolver(self, resolver, is_partial, root, info, **args):
512512
args.update(resolved._query)
513513
args_copy = args.copy()
514514
for arg_name, arg in args.copy().items():
515-
if "." in arg_name or arg_name not in self.model._fields_ordered + (
516-
'first', 'last', 'before', 'after') + tuple(
517-
self.filter_args.keys()):
515+
if "." in arg_name or arg_name not in (
516+
self.model._fields_ordered +
517+
('first', 'last', 'before', 'after') +
518+
tuple(self.filter_args.keys())
519+
):
518520
args_copy.pop(arg_name)
519521
if arg_name == '_id' and isinstance(arg, dict):
520522
operation = list(arg.keys())[0]
@@ -549,7 +551,7 @@ def connection_resolver(cls, resolver, connection_type, root, info, **args):
549551
if value:
550552
try:
551553
setattr(root, key, from_global_id(value)[1])
552-
except Exception as error:
554+
except Exception:
553555
pass
554556
iterable = resolver(root, info, **args)
555557

graphene_mongo/tests/test_utils.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
from ..utils import get_model_fields, is_valid_mongoengine_model
1+
from ..utils import get_model_fields, is_valid_mongoengine_model, get_query_fields
22
from .models import Article, Reporter, Child
3+
from . import types
4+
import graphene
35

46

57
def test_get_model_fields_no_duplication():
@@ -36,3 +38,66 @@ def test_get_base_model_fields():
3638

3739
def test_is_valid_mongoengine_mode():
3840
assert is_valid_mongoengine_model(Reporter)
41+
42+
43+
def test_get_query_fields():
44+
# Grab ResolveInfo objects from resolvers and set as nonlocal variables outside
45+
# Can't assert within resolvers, as the resolvers may not be run if there is an exception
46+
class Query(graphene.ObjectType):
47+
child = graphene.Field(types.ChildType)
48+
children = graphene.List(types.ChildUnionType)
49+
50+
def resolve_child(self, info, *args, **kwargs):
51+
test_get_query_fields.child_info = info
52+
53+
def resolve_children(self, info, *args, **kwargs):
54+
test_get_query_fields.children_info = info
55+
56+
query = """
57+
query Query {
58+
child {
59+
bar
60+
...testFragment
61+
}
62+
children {
63+
... on ChildType{
64+
baz
65+
...testFragment
66+
}
67+
... on AnotherChildType {
68+
qux
69+
}
70+
}
71+
}
72+
73+
fragment testFragment on ChildType {
74+
loc {
75+
type
76+
coordinates
77+
}
78+
}
79+
"""
80+
81+
schema = graphene.Schema(query=Query)
82+
schema.execute(query)
83+
84+
assert get_query_fields(test_get_query_fields.child_info) == {
85+
'bar': {},
86+
'loc': {
87+
'type': {},
88+
'coordinates': {}
89+
}
90+
}
91+
92+
assert get_query_fields(test_get_query_fields.children_info) == {
93+
'ChildType': {
94+
'baz': {},
95+
'loc': {
96+
'type': {},
97+
'coordinates': {}
98+
}
99+
},
100+
'AnotherChildType': {
101+
'qux': {}
102+
}
103+
}

graphene_mongo/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def collect_query_fields(node, fragments):
139139
leaf.name.value: collect_query_fields(leaf, fragments)
140140
})
141141
elif leaf.kind == 'fragment_spread':
142-
field.update(collect_query_fields(fragments[leaf['name']['value']],
142+
field.update(collect_query_fields(fragments[leaf.name.value],
143143
fragments))
144144
elif leaf.kind == 'inline_fragment':
145145
field.update({

0 commit comments

Comments
 (0)