Skip to content

Commit 6d1d224

Browse files
committed
incremental: allow nested defers at the same level
Replicates graphql/graphql-js@51f41eb
1 parent b3787b4 commit 6d1d224

File tree

6 files changed

+119
-66
lines changed

6 files changed

+119
-66
lines changed

docs/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@
184184
"graphql.execution.build_field_plan.FieldPlan",
185185
"graphql.execution.collect_fields.DeferUsage",
186186
"graphql.execution.execute.StreamArguments",
187+
"graphql.execution.execute.SubFieldPlan",
187188
"graphql.execution.execute.StreamUsage",
188189
"graphql.execution.map_async_iterable.map_async_iterable",
189190
"graphql.execution.incremental_publisher.CompletedResult",

src/graphql/execution/build_field_plan.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ class FieldGroup(NamedTuple):
3434

3535
fields: list[FieldDetails]
3636
defer_usages: DeferUsageSet | None = None
37-
known_defer_usages: DeferUsageSet | None = None
3837

3938
def to_nodes(self) -> list[FieldNode]:
4039
"""Return the field nodes in this group."""
@@ -59,22 +58,15 @@ class FieldPlan(NamedTuple):
5958

6059
grouped_field_set: GroupedFieldSet
6160
new_grouped_field_set_details_map: RefMap[DeferUsageSet, NewGroupedFieldSetDetails]
62-
new_defer_usages: list[DeferUsage]
6361

6462

6563
def build_field_plan(
6664
fields: dict[str, list[FieldDetails]],
6765
parent_defer_usages: DeferUsageSet | None = None,
68-
known_defer_usages: DeferUsageSet | None = None,
6966
) -> FieldPlan:
7067
"""Build a plan for executing fields."""
7168
if parent_defer_usages is None:
7269
parent_defer_usages = RefSet()
73-
if known_defer_usages is None:
74-
known_defer_usages = RefSet()
75-
76-
new_defer_usages: RefSet[DeferUsage] = RefSet()
77-
new_known_defer_usages: RefSet[DeferUsage] = RefSet(known_defer_usages)
7870

7971
grouped_field_set: GroupedFieldSet = {}
8072

@@ -93,9 +85,6 @@ def build_field_plan(
9385
in_original_result = True
9486
continue
9587
defer_usage_set.add(defer_usage)
96-
if defer_usage not in known_defer_usages:
97-
new_defer_usages.add(defer_usage)
98-
new_known_defer_usages.add(defer_usage)
9988
if in_original_result:
10089
defer_usage_set.clear()
10190
else:
@@ -112,7 +101,7 @@ def build_field_plan(
112101
if defer_usage_set == parent_defer_usages:
113102
field_group = grouped_field_set.get(response_key)
114103
if field_group is None: # pragma: no cover else
115-
field_group = FieldGroup([], defer_usage_set, new_known_defer_usages)
104+
field_group = FieldGroup([], defer_usage_set)
116105
grouped_field_set[response_key] = field_group
117106
field_group.fields.extend(field_details_list)
118107
continue
@@ -139,10 +128,8 @@ def build_field_plan(
139128

140129
field_group = new_grouped_field_set.get(response_key)
141130
if field_group is None: # pragma: no cover else
142-
field_group = FieldGroup([], defer_usage_set, new_known_defer_usages)
131+
field_group = FieldGroup([], defer_usage_set)
143132
new_grouped_field_set[response_key] = field_group
144133
field_group.fields.extend(field_details_list)
145134

146-
return FieldPlan(
147-
grouped_field_set, new_grouped_field_set_details_map, list(new_defer_usages)
148-
)
135+
return FieldPlan(grouped_field_set, new_grouped_field_set_details_map)

src/graphql/execution/collect_fields.py

Lines changed: 55 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
__all__ = [
2929
"CollectFieldsContext",
30+
"CollectedFields",
3031
"DeferUsage",
3132
"FieldDetails",
3233
"collect_fields",
@@ -69,13 +70,20 @@ class CollectFieldsContext(NamedTuple):
6970
visited_fragment_names: set[str]
7071

7172

73+
class CollectedFields(NamedTuple):
74+
"""Collected fields with new defer usages."""
75+
76+
fields: dict[str, list[FieldDetails]]
77+
new_defer_usages: list[DeferUsage]
78+
79+
7280
def collect_fields(
7381
schema: GraphQLSchema,
7482
fragments: dict[str, FragmentDefinitionNode],
7583
variable_values: dict[str, Any],
7684
runtime_type: GraphQLObjectType,
7785
operation: OperationDefinitionNode,
78-
) -> dict[str, list[FieldDetails]]:
86+
) -> CollectedFields:
7987
"""Collect fields.
8088
8189
Given a selection_set, collects all the fields and returns them.
@@ -87,6 +95,7 @@ def collect_fields(
8795
For internal use only.
8896
"""
8997
grouped_field_set: dict[str, list[FieldDetails]] = defaultdict(list)
98+
new_defer_usages: list[DeferUsage] = []
9099
context = CollectFieldsContext(
91100
schema,
92101
fragments,
@@ -96,8 +105,10 @@ def collect_fields(
96105
set(),
97106
)
98107

99-
collect_fields_impl(context, operation.selection_set, grouped_field_set)
100-
return grouped_field_set
108+
collect_fields_impl(
109+
context, operation.selection_set, grouped_field_set, new_defer_usages
110+
)
111+
return CollectedFields(grouped_field_set, new_defer_usages)
101112

102113

103114
def collect_subfields(
@@ -107,7 +118,7 @@ def collect_subfields(
107118
operation: OperationDefinitionNode,
108119
return_type: GraphQLObjectType,
109120
field_details: list[FieldDetails],
110-
) -> dict[str, list[FieldDetails]]:
121+
) -> CollectedFields:
111122
"""Collect subfields.
112123
113124
Given a list of field nodes, collects all the subfields of the passed in fields,
@@ -128,6 +139,7 @@ def collect_subfields(
128139
set(),
129140
)
130141
sub_grouped_field_set: dict[str, list[FieldDetails]] = defaultdict(list)
142+
new_defer_usages: list[DeferUsage] = []
131143

132144
for field_detail in field_details:
133145
node = field_detail.node
@@ -136,17 +148,18 @@ def collect_subfields(
136148
context,
137149
node.selection_set,
138150
sub_grouped_field_set,
151+
new_defer_usages,
139152
field_detail.defer_usage,
140153
)
141154

142-
return sub_grouped_field_set
155+
return CollectedFields(sub_grouped_field_set, new_defer_usages)
143156

144157

145158
def collect_fields_impl(
146159
context: CollectFieldsContext,
147160
selection_set: SelectionSetNode,
148161
grouped_field_set: dict[str, list[FieldDetails]],
149-
parent_defer_usage: DeferUsage | None = None,
162+
new_defer_usages: list[DeferUsage],
150163
defer_usage: DeferUsage | None = None,
151164
) -> None:
152165
"""Collect fields (internal implementation)."""
@@ -164,31 +177,39 @@ def collect_fields_impl(
164177
if not should_include_node(variable_values, selection):
165178
continue
166179
key = get_field_entry_key(selection)
167-
grouped_field_set[key].append(
168-
FieldDetails(selection, defer_usage or parent_defer_usage)
169-
)
180+
grouped_field_set[key].append(FieldDetails(selection, defer_usage))
170181
elif isinstance(selection, InlineFragmentNode):
171182
if not should_include_node(
172183
variable_values, selection
173184
) or not does_fragment_condition_match(schema, selection, runtime_type):
174185
continue
175186

176187
new_defer_usage = get_defer_usage(
177-
operation, variable_values, selection, parent_defer_usage
188+
operation, variable_values, selection, defer_usage
178189
)
179190

180-
collect_fields_impl(
181-
context,
182-
selection.selection_set,
183-
grouped_field_set,
184-
parent_defer_usage,
185-
new_defer_usage or defer_usage,
186-
)
191+
if new_defer_usage is None:
192+
collect_fields_impl(
193+
context,
194+
selection.selection_set,
195+
grouped_field_set,
196+
new_defer_usages,
197+
defer_usage,
198+
)
199+
else:
200+
new_defer_usages.append(new_defer_usage)
201+
collect_fields_impl(
202+
context,
203+
selection.selection_set,
204+
grouped_field_set,
205+
new_defer_usages,
206+
new_defer_usage,
207+
)
187208
elif isinstance(selection, FragmentSpreadNode): # pragma: no cover else
188209
frag_name = selection.name.value
189210

190211
new_defer_usage = get_defer_usage(
191-
operation, variable_values, selection, parent_defer_usage
212+
operation, variable_values, selection, defer_usage
192213
)
193214

194215
if new_defer_usage is None and (
@@ -205,14 +226,22 @@ def collect_fields_impl(
205226

206227
if new_defer_usage is None:
207228
visited_fragment_names.add(frag_name)
208-
209-
collect_fields_impl(
210-
context,
211-
fragment.selection_set,
212-
grouped_field_set,
213-
parent_defer_usage,
214-
new_defer_usage or defer_usage,
215-
)
229+
collect_fields_impl(
230+
context,
231+
fragment.selection_set,
232+
grouped_field_set,
233+
new_defer_usages,
234+
defer_usage,
235+
)
236+
else:
237+
new_defer_usages.append(new_defer_usage)
238+
collect_fields_impl(
239+
context,
240+
fragment.selection_set,
241+
grouped_field_set,
242+
new_defer_usages,
243+
new_defer_usage,
244+
)
216245

217246

218247
def get_defer_usage(

src/graphql/execution/execute.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,6 @@
7979
from .build_field_plan import (
8080
DeferUsageSet,
8181
FieldGroup,
82-
FieldPlan,
8382
GroupedFieldSet,
8483
NewGroupedFieldSetDetails,
8584
build_field_plan,
@@ -156,6 +155,14 @@ class StreamUsage(NamedTuple):
156155
field_group: FieldGroup
157156

158157

158+
class SubFieldPlan(NamedTuple):
159+
"""A plan for executing fields with defer usages."""
160+
161+
grouped_field_set: GroupedFieldSet
162+
new_grouped_field_set_details_map: RefMap[DeferUsageSet, NewGroupedFieldSetDetails]
163+
new_defer_usages: list[DeferUsage]
164+
165+
159166
class ExecutionContext:
160167
"""Data that must be available at all points during query execution.
161168
@@ -206,7 +213,7 @@ def __init__(
206213
if is_awaitable:
207214
self.is_awaitable = is_awaitable
208215
self._canceled_iterators: set[AsyncIterator] = set()
209-
self._field_plan_cache: dict[tuple, FieldPlan] = {}
216+
self._sub_field_plan_cache: dict[tuple, SubFieldPlan] = {}
210217
self._tasks: set[Awaitable] = set()
211218
self._stream_usages: RefMap[FieldGroup, StreamUsage] = RefMap()
212219

@@ -324,12 +331,10 @@ def execute_operation(
324331
)
325332
raise GraphQLError(msg, operation)
326333

327-
fields = collect_fields(
334+
fields, new_defer_usages = collect_fields(
328335
schema, self.fragments, self.variable_values, root_type, operation
329336
)
330-
grouped_field_set, new_grouped_field_set_details_map, new_defer_usages = (
331-
build_field_plan(fields)
332-
)
337+
grouped_field_set, new_grouped_field_set_details_map = build_field_plan(fields)
333338

334339
incremental_publisher = self.incremental_publisher
335340
new_defer_map = add_new_deferred_fragments(
@@ -1313,14 +1318,14 @@ def collect_and_execute_subfields(
13131318

13141319
def build_sub_field_plan(
13151320
self, return_type: GraphQLObjectType, field_group: FieldGroup
1316-
) -> FieldPlan:
1321+
) -> SubFieldPlan:
13171322
"""Collect subfields.
13181323
13191324
A memoized function for building subfield plans with regard to the return type.
13201325
Memoizing ensures the subfields are not repeatedly calculated, which saves
13211326
overhead when resolving lists of values.
13221327
"""
1323-
cache = self._field_plan_cache
1328+
cache = self._sub_field_plan_cache
13241329
# We cannot use the field_group itself as key for the cache, since it
13251330
# is not hashable as a list. We also do not want to use the field_group
13261331
# itself (converted to a tuple) as keys, since hashing them is slow.
@@ -1333,21 +1338,20 @@ def build_sub_field_plan(
13331338
if len(field_group) == 1 # optimize most frequent case
13341339
else (return_type, *map(id, field_group))
13351340
)
1336-
plan = cache.get(key)
1337-
if plan is None:
1338-
sub_fields = collect_subfields(
1341+
sub_field_plan = cache.get(key)
1342+
if sub_field_plan is None:
1343+
sub_fields, new_defer_usages = collect_subfields(
13391344
self.schema,
13401345
self.fragments,
13411346
self.variable_values,
13421347
self.operation,
13431348
return_type,
13441349
field_group.fields,
13451350
)
1346-
plan = build_field_plan(
1347-
sub_fields, field_group.defer_usages, field_group.known_defer_usages
1348-
)
1349-
cache[key] = plan
1350-
return plan
1351+
field_plan = build_field_plan(sub_fields, field_group.defer_usages)
1352+
sub_field_plan = SubFieldPlan(*field_plan, new_defer_usages)
1353+
cache[key] = sub_field_plan
1354+
return sub_field_plan
13511355

13521356
def map_source_to_response(
13531357
self, result_or_stream: ExecutionResult | AsyncIterable[Any]
@@ -2301,7 +2305,7 @@ def execute_subscription(
23012305
context.variable_values,
23022306
root_type,
23032307
context.operation,
2304-
)
2308+
).fields
23052309

23062310
first_root_field = next(iter(fields.items()))
23072311
response_name, field_details_list = first_root_field

src/graphql/validation/rules/single_field_subscriptions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def enter_operation_definition(
5252
variable_values,
5353
subscription_type,
5454
node,
55-
)
55+
).fields
5656
if len(fields) > 1:
5757
field_groups = list(fields.values())
5858
extra_field_groups = field_groups[1:]

0 commit comments

Comments
 (0)