Skip to content

Commit 58cd387

Browse files
authored
Implementation of execute_batch for async transports (#550)
1 parent f0fd64d commit 58cd387

14 files changed

+1516
-254
lines changed

gql/client.py

Lines changed: 238 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,24 @@ def _build_schema_from_introspection(
184184
self.introspection = cast(IntrospectionQuery, execution_result.data)
185185
self.schema = build_client_schema(self.introspection)
186186

187+
@staticmethod
188+
def _get_event_loop() -> asyncio.AbstractEventLoop:
189+
"""Get the current asyncio event loop.
190+
191+
Or create a new event loop if there isn't one (in a new Thread).
192+
"""
193+
try:
194+
with warnings.catch_warnings():
195+
warnings.filterwarnings(
196+
"ignore", message="There is no current event loop"
197+
)
198+
loop = asyncio.get_event_loop()
199+
except RuntimeError:
200+
loop = asyncio.new_event_loop()
201+
asyncio.set_event_loop(loop)
202+
203+
return loop
204+
187205
@overload
188206
def execute_sync(
189207
self,
@@ -358,6 +376,58 @@ async def execute_async(
358376
**kwargs,
359377
)
360378

379+
@overload
380+
async def execute_batch_async(
381+
self,
382+
requests: List[GraphQLRequest],
383+
*,
384+
serialize_variables: Optional[bool] = None,
385+
parse_result: Optional[bool] = None,
386+
get_execution_result: Literal[False] = ...,
387+
**kwargs: Any,
388+
) -> List[Dict[str, Any]]: ... # pragma: no cover
389+
390+
@overload
391+
async def execute_batch_async(
392+
self,
393+
requests: List[GraphQLRequest],
394+
*,
395+
serialize_variables: Optional[bool] = None,
396+
parse_result: Optional[bool] = None,
397+
get_execution_result: Literal[True],
398+
**kwargs: Any,
399+
) -> List[ExecutionResult]: ... # pragma: no cover
400+
401+
@overload
402+
async def execute_batch_async(
403+
self,
404+
requests: List[GraphQLRequest],
405+
*,
406+
serialize_variables: Optional[bool] = None,
407+
parse_result: Optional[bool] = None,
408+
get_execution_result: bool,
409+
**kwargs: Any,
410+
) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: ... # pragma: no cover
411+
412+
async def execute_batch_async(
413+
self,
414+
requests: List[GraphQLRequest],
415+
*,
416+
serialize_variables: Optional[bool] = None,
417+
parse_result: Optional[bool] = None,
418+
get_execution_result: bool = False,
419+
**kwargs: Any,
420+
) -> Union[List[Dict[str, Any]], List[ExecutionResult]]:
421+
""":meta private:"""
422+
async with self as session:
423+
return await session.execute_batch(
424+
requests,
425+
serialize_variables=serialize_variables,
426+
parse_result=parse_result,
427+
get_execution_result=get_execution_result,
428+
**kwargs,
429+
)
430+
361431
@overload
362432
def execute(
363433
self,
@@ -430,17 +500,7 @@ def execute(
430500
"""
431501

432502
if isinstance(self.transport, AsyncTransport):
433-
# Get the current asyncio event loop
434-
# Or create a new event loop if there isn't one (in a new Thread)
435-
try:
436-
with warnings.catch_warnings():
437-
warnings.filterwarnings(
438-
"ignore", message="There is no current event loop"
439-
)
440-
loop = asyncio.get_event_loop()
441-
except RuntimeError:
442-
loop = asyncio.new_event_loop()
443-
asyncio.set_event_loop(loop)
503+
loop = self._get_event_loop()
444504

445505
assert not loop.is_running(), (
446506
"Cannot run client.execute(query) if an asyncio loop is running."
@@ -537,7 +597,24 @@ def execute_batch(
537597
"""
538598

539599
if isinstance(self.transport, AsyncTransport):
540-
raise NotImplementedError("Batching is not implemented for async yet.")
600+
loop = self._get_event_loop()
601+
602+
assert not loop.is_running(), (
603+
"Cannot run client.execute_batch(query) if an asyncio loop is running."
604+
" Use 'await client.execute_batch(query)' instead."
605+
)
606+
607+
data = loop.run_until_complete(
608+
self.execute_batch_async(
609+
requests,
610+
serialize_variables=serialize_variables,
611+
parse_result=parse_result,
612+
get_execution_result=get_execution_result,
613+
**kwargs,
614+
)
615+
)
616+
617+
return data
541618

542619
else: # Sync transports
543620
return self.execute_batch_sync(
@@ -675,17 +752,12 @@ def subscribe(
675752
We need an async transport for this functionality.
676753
"""
677754

678-
# Get the current asyncio event loop
679-
# Or create a new event loop if there isn't one (in a new Thread)
680-
try:
681-
with warnings.catch_warnings():
682-
warnings.filterwarnings(
683-
"ignore", message="There is no current event loop"
684-
)
685-
loop = asyncio.get_event_loop()
686-
except RuntimeError:
687-
loop = asyncio.new_event_loop()
688-
asyncio.set_event_loop(loop)
755+
loop = self._get_event_loop()
756+
757+
assert not loop.is_running(), (
758+
"Cannot run client.subscribe(query) if an asyncio loop is running."
759+
" Use 'await client.subscribe_async(query)' instead."
760+
)
689761

690762
async_generator: Union[
691763
AsyncGenerator[Dict[str, Any], None], AsyncGenerator[ExecutionResult, None]
@@ -699,11 +771,6 @@ def subscribe(
699771
**kwargs,
700772
)
701773

702-
assert not loop.is_running(), (
703-
"Cannot run client.subscribe(query) if an asyncio loop is running."
704-
" Use 'await client.subscribe_async(query)' instead."
705-
)
706-
707774
try:
708775
while True:
709776
# Note: we need to create a task here in order to be able to close
@@ -1626,6 +1693,149 @@ async def execute(
16261693

16271694
return result.data
16281695

1696+
async def _execute_batch(
1697+
self,
1698+
requests: List[GraphQLRequest],
1699+
*,
1700+
serialize_variables: Optional[bool] = None,
1701+
parse_result: Optional[bool] = None,
1702+
validate_document: Optional[bool] = True,
1703+
**kwargs: Any,
1704+
) -> List[ExecutionResult]:
1705+
"""Execute multiple GraphQL requests in a batch, using
1706+
the async transport, returning a list of ExecutionResult objects.
1707+
1708+
:param requests: List of requests that will be executed.
1709+
:param serialize_variables: whether the variable values should be
1710+
serialized. Used for custom scalars and/or enums.
1711+
By default use the serialize_variables argument of the client.
1712+
:param parse_result: Whether gql will deserialize the result.
1713+
By default use the parse_results argument of the client.
1714+
:param validate_document: Whether we still need to validate the document.
1715+
1716+
The extra arguments are passed to the transport execute_batch method."""
1717+
1718+
# Validate document
1719+
if self.client.schema:
1720+
1721+
if validate_document:
1722+
for req in requests:
1723+
self.client.validate(req.document)
1724+
1725+
# Parse variable values for custom scalars if requested
1726+
if serialize_variables or (
1727+
serialize_variables is None and self.client.serialize_variables
1728+
):
1729+
requests = [
1730+
(
1731+
req.serialize_variable_values(self.client.schema)
1732+
if req.variable_values is not None
1733+
else req
1734+
)
1735+
for req in requests
1736+
]
1737+
1738+
results = await self.transport.execute_batch(requests, **kwargs)
1739+
1740+
# Unserialize the result if requested
1741+
if self.client.schema:
1742+
if parse_result or (parse_result is None and self.client.parse_results):
1743+
for result in results:
1744+
result.data = parse_result_fn(
1745+
self.client.schema,
1746+
req.document,
1747+
result.data,
1748+
operation_name=req.operation_name,
1749+
)
1750+
1751+
return results
1752+
1753+
@overload
1754+
async def execute_batch(
1755+
self,
1756+
requests: List[GraphQLRequest],
1757+
*,
1758+
serialize_variables: Optional[bool] = None,
1759+
parse_result: Optional[bool] = None,
1760+
get_execution_result: Literal[False] = ...,
1761+
**kwargs: Any,
1762+
) -> List[Dict[str, Any]]: ... # pragma: no cover
1763+
1764+
@overload
1765+
async def execute_batch(
1766+
self,
1767+
requests: List[GraphQLRequest],
1768+
*,
1769+
serialize_variables: Optional[bool] = None,
1770+
parse_result: Optional[bool] = None,
1771+
get_execution_result: Literal[True],
1772+
**kwargs: Any,
1773+
) -> List[ExecutionResult]: ... # pragma: no cover
1774+
1775+
@overload
1776+
async def execute_batch(
1777+
self,
1778+
requests: List[GraphQLRequest],
1779+
*,
1780+
serialize_variables: Optional[bool] = None,
1781+
parse_result: Optional[bool] = None,
1782+
get_execution_result: bool,
1783+
**kwargs: Any,
1784+
) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: ... # pragma: no cover
1785+
1786+
async def execute_batch(
1787+
self,
1788+
requests: List[GraphQLRequest],
1789+
*,
1790+
serialize_variables: Optional[bool] = None,
1791+
parse_result: Optional[bool] = None,
1792+
get_execution_result: bool = False,
1793+
**kwargs: Any,
1794+
) -> Union[List[Dict[str, Any]], List[ExecutionResult]]:
1795+
"""Execute multiple GraphQL requests in a batch, using
1796+
the async transport. This method sends the requests to the server all at once.
1797+
1798+
Raises a TransportQueryError if an error has been returned in any
1799+
ExecutionResult.
1800+
1801+
:param requests: List of requests that will be executed.
1802+
:param serialize_variables: whether the variable values should be
1803+
serialized. Used for custom scalars and/or enums.
1804+
By default use the serialize_variables argument of the client.
1805+
:param parse_result: Whether gql will deserialize the result.
1806+
By default use the parse_results argument of the client.
1807+
:param get_execution_result: return the full ExecutionResult instance instead of
1808+
only the "data" field. Necessary if you want to get the "extensions" field.
1809+
1810+
The extra arguments are passed to the transport execute method."""
1811+
1812+
# Validate and execute on the transport
1813+
results = await self._execute_batch(
1814+
requests,
1815+
serialize_variables=serialize_variables,
1816+
parse_result=parse_result,
1817+
**kwargs,
1818+
)
1819+
1820+
for result in results:
1821+
# Raise an error if an error is returned in the ExecutionResult object
1822+
if result.errors:
1823+
raise TransportQueryError(
1824+
str_first_element(result.errors),
1825+
errors=result.errors,
1826+
data=result.data,
1827+
extensions=result.extensions,
1828+
)
1829+
1830+
assert (
1831+
result.data is not None
1832+
), "Transport returned an ExecutionResult without data or errors"
1833+
1834+
if get_execution_result:
1835+
return results
1836+
1837+
return cast(List[Dict[str, Any]], [result.data for result in results])
1838+
16291839
async def fetch_schema(self) -> None:
16301840
"""Fetch the GraphQL schema explicitly using introspection.
16311841

0 commit comments

Comments
 (0)