Skip to content

Commit 94ff3de

Browse files
authored
Add hack for updating final result with data from root value (#1170)
1 parent 85f538c commit 94ff3de

File tree

4 files changed

+111
-13
lines changed

4 files changed

+111
-13
lines changed

ariadne/graphql.py

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from .format_error import format_error
3737
from .logger import log_error
3838
from .types import (
39+
BaseProxyRootValue,
3940
ErrorFormatter,
4041
ExtensionList,
4142
GraphQLResult,
@@ -146,6 +147,8 @@ async def graphql(
146147
`**kwargs`: any kwargs not used by `graphql` are passed to
147148
`graphql.graphql`.
148149
"""
150+
result_update: Optional[BaseProxyRootValue] = None
151+
149152
extension_manager = ExtensionManager(extensions, context_value)
150153

151154
with extension_manager.request():
@@ -200,7 +203,11 @@ async def graphql(
200203
if isawaitable(root_value):
201204
root_value = await root_value
202205

203-
result = execute(
206+
if isinstance(root_value, BaseProxyRootValue):
207+
result_update = root_value
208+
root_value = root_value.root_value
209+
210+
exec_result = execute(
204211
schema,
205212
document,
206213
root_value=root_value,
@@ -214,25 +221,35 @@ async def graphql(
214221
**kwargs,
215222
)
216223

217-
if isawaitable(result):
218-
result = await cast(Awaitable[ExecutionResult], result)
224+
if isawaitable(exec_result):
225+
exec_result = await cast(Awaitable[ExecutionResult], exec_result)
219226
except GraphQLError as error:
220-
return handle_graphql_errors(
227+
error_result = handle_graphql_errors(
221228
[error],
222229
logger=logger,
223230
error_formatter=error_formatter,
224231
debug=debug,
225232
extension_manager=extension_manager,
226233
)
227234

228-
return handle_query_result(
229-
result,
235+
if result_update:
236+
return result_update.update_result(error_result)
237+
238+
return error_result
239+
240+
result = handle_query_result(
241+
exec_result,
230242
logger=logger,
231243
error_formatter=error_formatter,
232244
debug=debug,
233245
extension_manager=extension_manager,
234246
)
235247

248+
if result_update:
249+
return result_update.update_result(result)
250+
251+
return result
252+
236253

237254
def graphql_sync(
238255
schema: GraphQLSchema,
@@ -321,6 +338,8 @@ def graphql_sync(
321338
`**kwargs`: any kwargs not used by `graphql_sync` are passed to
322339
`graphql.graphql_sync`.
323340
"""
341+
result_update: Optional[BaseProxyRootValue] = None
342+
324343
extension_manager = ExtensionManager(extensions, context_value)
325344

326345
with extension_manager.request():
@@ -379,7 +398,11 @@ def graphql_sync(
379398
"in synchronous query executor."
380399
)
381400

382-
result = execute_sync(
401+
if isinstance(root_value, BaseProxyRootValue):
402+
result_update = root_value
403+
root_value = root_value.root_value
404+
405+
exec_result = execute_sync(
383406
schema,
384407
document,
385408
root_value=root_value,
@@ -393,28 +416,38 @@ def graphql_sync(
393416
**kwargs,
394417
)
395418

396-
if isawaitable(result):
397-
ensure_future(cast(Awaitable[ExecutionResult], result)).cancel()
419+
if isawaitable(exec_result):
420+
ensure_future(cast(Awaitable[ExecutionResult], exec_result)).cancel()
398421
raise RuntimeError(
399422
"GraphQL execution failed to complete synchronously."
400423
)
401424
except GraphQLError as error:
402-
return handle_graphql_errors(
425+
error_result = handle_graphql_errors(
403426
[error],
404427
logger=logger,
405428
error_formatter=error_formatter,
406429
debug=debug,
407430
extension_manager=extension_manager,
408431
)
409432

410-
return handle_query_result(
411-
result,
433+
if result_update:
434+
return result_update.update_result(error_result)
435+
436+
return error_result
437+
438+
result = handle_query_result(
439+
exec_result,
412440
logger=logger,
413441
error_formatter=error_formatter,
414442
debug=debug,
415443
extension_manager=extension_manager,
416444
)
417445

446+
if result_update:
447+
return result_update.update_result(result)
448+
449+
return result
450+
418451

419452
async def subscribe(
420453
schema: GraphQLSchema,

ariadne/types.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
"ErrorFormatter",
3535
"ContextValue",
3636
"RootValue",
37+
"BaseProxyRootValue",
3738
"QueryParser",
3839
"QueryValidator",
3940
"ValidationRules",
@@ -228,6 +229,35 @@ async def get_context_value(request: Request, _):
228229
Callable[[Optional[Any], Optional[str], Optional[dict], DocumentNode], Any],
229230
]
230231

232+
233+
class BaseProxyRootValue:
234+
"""A `RootValue` wrapper that includes result JSON update logic.
235+
236+
Can be returned by the `RootValue` callable. Not used by Ariadne directly
237+
but part of the support for Ariadne GraphQL Proxy.
238+
239+
# Attributes
240+
241+
- `root_value: Optional[dict]`: `RootValue` to use during query execution.
242+
"""
243+
244+
__slots__ = ("root_value",)
245+
246+
root_value: Optional[dict]
247+
248+
def __init__(self, root_value: Optional[dict] = None):
249+
self.root_value = root_value
250+
251+
def update_result(self, result: GraphQLResult) -> GraphQLResult:
252+
"""An update function used to create a final `GraphQL` result tuple to
253+
create a JSON response from.
254+
255+
Default implementation in `BaseProxyRootValue` is a passthrough that
256+
returns `result` value without any changes.
257+
"""
258+
return result
259+
260+
231261
"""Type of `query_parser` option of GraphQL servers.
232262
233263
Enables customization of server's GraphQL parsing logic. If not set or `None`,

tests/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def type_defs():
2323
testContext: String
2424
testRoot: String
2525
testError: Boolean
26+
context: String
2627
}
2728
2829
type Mutation {

tests/test_graphql.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from graphql.validation.rules import ValidationRule
44

55
from ariadne import graphql, graphql_sync, subscribe
6+
from ariadne.types import BaseProxyRootValue
67

78

89
class AlwaysInvalid(ValidationRule):
@@ -12,6 +13,12 @@ def leave_operation_definition( # pylint: disable=unused-argument
1213
self.context.report_error(GraphQLError("Invalid"))
1314

1415

16+
class ProxyRootValue(BaseProxyRootValue):
17+
def update_result(self, result):
18+
success, data = result
19+
return success, {"updated": True, **data}
20+
21+
1522
def test_graphql_sync_executes_the_query(schema):
1623
success, result = graphql_sync(schema, {"query": '{ hello(name: "world") }'})
1724
assert success
@@ -51,8 +58,21 @@ def test_graphql_sync_prevents_introspection_query_when_option_is_disabled(schem
5158
)
5259

5360

61+
def test_graphql_sync_executes_the_query_using_result_update_obj(schema):
62+
success, result = graphql_sync(
63+
schema,
64+
{"query": "{ context }"},
65+
root_value=ProxyRootValue({"context": "Works!"}),
66+
)
67+
assert success
68+
assert result == {
69+
"data": {"context": "Works!"},
70+
"updated": True,
71+
}
72+
73+
5474
@pytest.mark.asyncio
55-
async def test_graphql_execute_the_query(schema):
75+
async def test_graphql_executes_the_query(schema):
5676
success, result = await graphql(schema, {"query": '{ hello(name: "world") }'})
5777
assert success
5878
assert result["data"] == {"hello": "Hello, world!"}
@@ -94,6 +114,20 @@ async def test_graphql_prevents_introspection_query_when_option_is_disabled(sche
94114
)
95115

96116

117+
@pytest.mark.asyncio
118+
async def test_graphql_executes_the_query_using_result_update_obj(schema):
119+
success, result = await graphql(
120+
schema,
121+
{"query": "{ context }"},
122+
root_value=ProxyRootValue({"context": "Works!"}),
123+
)
124+
assert success
125+
assert result == {
126+
"data": {"context": "Works!"},
127+
"updated": True,
128+
}
129+
130+
97131
@pytest.mark.asyncio
98132
async def test_subscription_returns_an_async_iterator(schema):
99133
success, result = await subscribe(schema, {"query": "subscription { ping }"})

0 commit comments

Comments
 (0)