Skip to content

Feature execute batch async #550

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
266 changes: 238 additions & 28 deletions gql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,24 @@ def _build_schema_from_introspection(
self.introspection = cast(IntrospectionQuery, execution_result.data)
self.schema = build_client_schema(self.introspection)

@staticmethod
def _get_event_loop() -> asyncio.AbstractEventLoop:
"""Get the current asyncio event loop.

Or create a new event loop if there isn't one (in a new Thread).
"""
try:
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", message="There is no current event loop"
)
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

return loop

@overload
def execute_sync(
self,
Expand Down Expand Up @@ -358,6 +376,58 @@ async def execute_async(
**kwargs,
)

@overload
async def execute_batch_async(
self,
requests: List[GraphQLRequest],
*,
serialize_variables: Optional[bool] = None,
parse_result: Optional[bool] = None,
get_execution_result: Literal[False] = ...,
**kwargs: Any,
) -> List[Dict[str, Any]]: ... # pragma: no cover

@overload
async def execute_batch_async(
self,
requests: List[GraphQLRequest],
*,
serialize_variables: Optional[bool] = None,
parse_result: Optional[bool] = None,
get_execution_result: Literal[True],
**kwargs: Any,
) -> List[ExecutionResult]: ... # pragma: no cover

@overload
async def execute_batch_async(
self,
requests: List[GraphQLRequest],
*,
serialize_variables: Optional[bool] = None,
parse_result: Optional[bool] = None,
get_execution_result: bool,
**kwargs: Any,
) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: ... # pragma: no cover

async def execute_batch_async(
self,
requests: List[GraphQLRequest],
*,
serialize_variables: Optional[bool] = None,
parse_result: Optional[bool] = None,
get_execution_result: bool = False,
**kwargs: Any,
) -> Union[List[Dict[str, Any]], List[ExecutionResult]]:
""":meta private:"""
async with self as session:
return await session.execute_batch(
requests,
serialize_variables=serialize_variables,
parse_result=parse_result,
get_execution_result=get_execution_result,
**kwargs,
)

@overload
def execute(
self,
Expand Down Expand Up @@ -430,17 +500,7 @@ def execute(
"""

if isinstance(self.transport, AsyncTransport):
# Get the current asyncio event loop
# Or create a new event loop if there isn't one (in a new Thread)
try:
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", message="There is no current event loop"
)
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop = self._get_event_loop()

assert not loop.is_running(), (
"Cannot run client.execute(query) if an asyncio loop is running."
Expand Down Expand Up @@ -537,7 +597,24 @@ def execute_batch(
"""

if isinstance(self.transport, AsyncTransport):
raise NotImplementedError("Batching is not implemented for async yet.")
loop = self._get_event_loop()

assert not loop.is_running(), (
"Cannot run client.execute_batch(query) if an asyncio loop is running."
" Use 'await client.execute_batch(query)' instead."
)

data = loop.run_until_complete(
self.execute_batch_async(
requests,
serialize_variables=serialize_variables,
parse_result=parse_result,
get_execution_result=get_execution_result,
**kwargs,
)
)

return data

else: # Sync transports
return self.execute_batch_sync(
Expand Down Expand Up @@ -675,17 +752,12 @@ def subscribe(
We need an async transport for this functionality.
"""

# Get the current asyncio event loop
# Or create a new event loop if there isn't one (in a new Thread)
try:
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", message="There is no current event loop"
)
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop = self._get_event_loop()

assert not loop.is_running(), (
"Cannot run client.subscribe(query) if an asyncio loop is running."
" Use 'await client.subscribe_async(query)' instead."
)

async_generator: Union[
AsyncGenerator[Dict[str, Any], None], AsyncGenerator[ExecutionResult, None]
Expand All @@ -699,11 +771,6 @@ def subscribe(
**kwargs,
)

assert not loop.is_running(), (
"Cannot run client.subscribe(query) if an asyncio loop is running."
" Use 'await client.subscribe_async(query)' instead."
)

try:
while True:
# Note: we need to create a task here in order to be able to close
Expand Down Expand Up @@ -1626,6 +1693,149 @@ async def execute(

return result.data

async def _execute_batch(
self,
requests: List[GraphQLRequest],
*,
serialize_variables: Optional[bool] = None,
parse_result: Optional[bool] = None,
validate_document: Optional[bool] = True,
**kwargs: Any,
) -> List[ExecutionResult]:
"""Execute multiple GraphQL requests in a batch, using
the async transport, returning a list of ExecutionResult objects.

:param requests: List of requests that will be executed.
:param serialize_variables: whether the variable values should be
serialized. Used for custom scalars and/or enums.
By default use the serialize_variables argument of the client.
:param parse_result: Whether gql will deserialize the result.
By default use the parse_results argument of the client.
:param validate_document: Whether we still need to validate the document.

The extra arguments are passed to the transport execute_batch method."""

# Validate document
if self.client.schema:

if validate_document:
for req in requests:
self.client.validate(req.document)

# Parse variable values for custom scalars if requested
if serialize_variables or (
serialize_variables is None and self.client.serialize_variables
):
requests = [
(
req.serialize_variable_values(self.client.schema)
if req.variable_values is not None
else req
)
for req in requests
]

results = await self.transport.execute_batch(requests, **kwargs)

# Unserialize the result if requested
if self.client.schema:
if parse_result or (parse_result is None and self.client.parse_results):
for result in results:
result.data = parse_result_fn(
self.client.schema,
req.document,
result.data,
operation_name=req.operation_name,
)

return results

@overload
async def execute_batch(
self,
requests: List[GraphQLRequest],
*,
serialize_variables: Optional[bool] = None,
parse_result: Optional[bool] = None,
get_execution_result: Literal[False] = ...,
**kwargs: Any,
) -> List[Dict[str, Any]]: ... # pragma: no cover

@overload
async def execute_batch(
self,
requests: List[GraphQLRequest],
*,
serialize_variables: Optional[bool] = None,
parse_result: Optional[bool] = None,
get_execution_result: Literal[True],
**kwargs: Any,
) -> List[ExecutionResult]: ... # pragma: no cover

@overload
async def execute_batch(
self,
requests: List[GraphQLRequest],
*,
serialize_variables: Optional[bool] = None,
parse_result: Optional[bool] = None,
get_execution_result: bool,
**kwargs: Any,
) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: ... # pragma: no cover

async def execute_batch(
self,
requests: List[GraphQLRequest],
*,
serialize_variables: Optional[bool] = None,
parse_result: Optional[bool] = None,
get_execution_result: bool = False,
**kwargs: Any,
) -> Union[List[Dict[str, Any]], List[ExecutionResult]]:
"""Execute multiple GraphQL requests in a batch, using
the async transport. This method sends the requests to the server all at once.

Raises a TransportQueryError if an error has been returned in any
ExecutionResult.

:param requests: List of requests that will be executed.
:param serialize_variables: whether the variable values should be
serialized. Used for custom scalars and/or enums.
By default use the serialize_variables argument of the client.
:param parse_result: Whether gql will deserialize the result.
By default use the parse_results argument of the client.
:param get_execution_result: return the full ExecutionResult instance instead of
only the "data" field. Necessary if you want to get the "extensions" field.

The extra arguments are passed to the transport execute method."""

# Validate and execute on the transport
results = await self._execute_batch(
requests,
serialize_variables=serialize_variables,
parse_result=parse_result,
**kwargs,
)

for result in results:
# Raise an error if an error is returned in the ExecutionResult object
if result.errors:
raise TransportQueryError(
str_first_element(result.errors),
errors=result.errors,
data=result.data,
extensions=result.extensions,
)

assert (
result.data is not None
), "Transport returned an ExecutionResult without data or errors"

if get_execution_result:
return results

return cast(List[Dict[str, Any]], [result.data for result in results])

async def fetch_schema(self) -> None:
"""Fetch the GraphQL schema explicitly using introspection.

Expand Down
Loading
Loading