Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP defer support #3753

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Release type: minor

@defer 👀
193 changes: 145 additions & 48 deletions poetry.lock

Large diffs are not rendered by default.

22 changes: 19 additions & 3 deletions strawberry/http/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
from typing_extensions import Literal, TypedDict
from typing import TYPE_CHECKING, Any, Optional, Union
from typing_extensions import Literal, NotRequired, TypedDict

from strawberry.types import InitialIncrementalExecutionResult

if TYPE_CHECKING:
from strawberry.types import ExecutionResult
Expand All @@ -14,14 +16,20 @@ class GraphQLHTTPResponse(TypedDict, total=False):
extensions: Optional[dict[str, object]]


def process_result(result: ExecutionResult) -> GraphQLHTTPResponse:
def process_result(
result: Union[ExecutionResult, InitialIncrementalExecutionResult],
) -> GraphQLHTTPResponse:
data: GraphQLHTTPResponse = {"data": result.data}

if result.errors:
data["errors"] = [err.formatted for err in result.errors]
if result.extensions:
data["extensions"] = result.extensions

if isinstance(result, InitialIncrementalExecutionResult):
data["hasNext"] = result.has_next
data["pending"] = result.pending

return data


Expand All @@ -35,8 +43,16 @@ class GraphQLRequestData:
protocol: Literal["http", "multipart-subscription"] = "http"


class IncrementalGraphQLHTTPResponse(TypedDict):
incremental: list[GraphQLHTTPResponse]
hasNext: bool
extensions: NotRequired[dict[str, Any]]
completed: list[GraphQLHTTPResponse]


__all__ = [
"GraphQLHTTPResponse",
"GraphQLRequestData",
"IncrementalGraphQLHTTPResponse",
"process_result",
]
103 changes: 98 additions & 5 deletions strawberry/http/async_base_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from strawberry.http import (
GraphQLHTTPResponse,
GraphQLRequestData,
IncrementalGraphQLHTTPResponse,
process_result,
)
from strawberry.http.ides import GraphQL_IDE
Expand Down Expand Up @@ -50,6 +51,28 @@
WebSocketResponse,
)

try:
from graphql.execution.execute import ExperimentalIncrementalExecutionResults
from graphql.execution.incremental_publisher import (
IncrementalDeferResult,
IncrementalResult,
IncrementalStreamResult,
InitialIncrementalExecutionResult,
PendingResult,
SubsequentIncrementalExecutionResult,
)

except ImportError as e:
from types import NoneType

print(e)

InitialIncrementalExecutionResult = NoneType
IncrementalResult = NoneType
IncrementalStreamResult = NoneType
SubsequentIncrementalExecutionResult = NoneType
PendingResult = NoneType


class AsyncHTTPRequestAdapter(abc.ABC):
@property
Expand Down Expand Up @@ -337,6 +360,29 @@ async def run(
except MissingQueryError as e:
raise HTTPException(400, "No GraphQL query found in the request") from e

if isinstance(result, ExperimentalIncrementalExecutionResults):

async def stream():
yield "---"
response = await self.process_result(request, result.initial_result)
yield self.encode_multipart_data(response, "-")

async for value in result.subsequent_results:
response = await self.process_subsequent_result(request, value)
yield self.encode_multipart_data(response, "-")

yield "--\r\n"

return await self.create_streaming_response(
request,
stream,
sub_response,
headers={
"Transfer-Encoding": "chunked",
"Content-Type": 'multipart/mixed; boundary="-"',
},
)

if isinstance(result, SubscriptionExecutionResult):
stream = self._get_stream(request, result)

Expand All @@ -360,12 +406,15 @@ async def run(
)

def encode_multipart_data(self, data: Any, separator: str) -> str:
encoded_data = self.encode_json(data)

return "".join(
[
f"\r\n--{separator}\r\n",
"Content-Type: application/json\r\n\r\n",
self.encode_json(data),
"\n",
"\r\n",
"Content-Type: application/json; charset=utf-8\r\n",
"\r\n",
encoded_data,
f"\r\n--{separator}",
]
)

Expand Down Expand Up @@ -475,9 +524,53 @@ async def parse_http_body(
protocol=protocol,
)

async def process_subsequent_result(
self,
request: Request,
result: "SubsequentIncrementalExecutionResult",
) -> IncrementalGraphQLHTTPResponse:
data = {
"incremental": [
await self.process_result(request, value)
for value in result.incremental
],
"completed": [
completed_result.formatted for completed_result in result.completed
],
"hasNext": result.has_next,
"extensions": result.extensions,
}

return data

async def process_result(
self, request: Request, result: ExecutionResult
self,
request: Request,
result: Union[ExecutionResult, InitialIncrementalExecutionResult],
) -> GraphQLHTTPResponse:
if isinstance(result, InitialIncrementalExecutionResult):
# TODO: fix this mess
from strawberry.types import (
InitialIncrementalExecutionResult as InitialIncrementalExecutionResultType,
)

# TODO: do this where we create ExecutionResult
# or maybe remove our wrappers and just GraphQL core's types
result = InitialIncrementalExecutionResultType(
data=result.data,
pending=[pending_result.formatted for pending_result in result.pending],
has_next=result.has_next,
extensions=result.extensions,
errors=result.errors,
)

result = await self.schema._handle_execution_result(
context=self.schema.execution_context,
result=result,
extensions_runner=self.schema.extensions_runner,
process_errors=self.schema.process_errors,
)

return process_result(result)

async def on_ws_connect(
Expand Down
40 changes: 38 additions & 2 deletions strawberry/schema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@
from graphql.execution import ExecutionContext as GraphQLExecutionContext
from graphql.execution import execute, subscribe
from graphql.execution.middleware import MiddlewareManager
from graphql.type.directives import specified_directives
from graphql.type.directives import (
GraphQLDeferDirective,
GraphQLStreamDirective,
specified_directives,
)
from graphql.validation import validate

from strawberry import relay
Expand Down Expand Up @@ -82,6 +86,9 @@
from strawberry.types.scalar import ScalarDefinition, ScalarWrapper
from strawberry.types.union import StrawberryUnion

ProcessErrors: TypeAlias = (
"Callable[[list[GraphQLError], Optional[ExecutionContext]], None]"
)
SubscriptionResult: TypeAlias = AsyncGenerator[
Union[PreExecutionError, ExecutionResult], None
]
Expand Down Expand Up @@ -263,7 +270,11 @@ class Query:
query=query_type,
mutation=mutation_type,
subscription=subscription_type if subscription else None,
directives=specified_directives + tuple(graphql_directives),
directives=(
specified_directives
+ tuple(graphql_directives)
+ (GraphQLDeferDirective, GraphQLStreamDirective)
),
types=graphql_types,
extensions={
GraphQLCoreConverter.DEFINITION_BACKREF: self,
Expand Down Expand Up @@ -353,6 +364,31 @@ def _create_execution_context(
provided_operation_name=operation_name,
)

# TODO: is this the right place to do this?
async def _handle_execution_result(
self,
context: ExecutionContext,
result: Union[GraphQLExecutionResult, ExecutionResult],
extensions_runner: SchemaExtensionsRunner,
process_errors: ProcessErrors | None,
) -> ExecutionResult:
# Set errors on the context so that it's easier
# to access in extensions
if result.errors:
context.errors = result.errors

if process_errors:
process_errors(result.errors, context)

if isinstance(result, GraphQLExecutionResult):
result = ExecutionResult(data=result.data, errors=result.errors)

# TODO: not correct when handling incremental results
result.extensions = await extensions_runner.get_extensions_results(context)

context.result = result # type: ignore # mypy failed to deduce correct type.
return result

@lru_cache
def get_type_by_name(
self, name: str
Expand Down
7 changes: 2 additions & 5 deletions strawberry/static/graphiql.html
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,7 @@
<link
crossorigin
rel="stylesheet"
href="https://unpkg.com/graphiql@3.0.9/graphiql.min.css"
integrity="sha384-yz3/sqpuplkA7msMo0FE4ekg0xdwdvZ8JX9MVZREsxipqjU4h8IRfmAMRcb1QpUy"
href="https://unpkg.com/graphiql@3.8.3/graphiql.min.css"
/>

<link
Expand All @@ -77,13 +76,11 @@
<div id="graphiql" class="graphiql-container">Loading...</div>
<script
crossorigin
src="https://unpkg.com/graphiql@3.0.9/graphiql.min.js"
integrity="sha384-Mjte+vxCWz1ZYCzszGHiJqJa5eAxiqI4mc3BErq7eDXnt+UGLXSEW7+i0wmfPiji"
src="https://unpkg.com/graphiql@3.8.3/graphiql.min.js"
></script>
<script
crossorigin
src="https://unpkg.com/@graphiql/plugin-explorer@1.0.2/dist/index.umd.js"
integrity="sha384-2oonKe47vfHIZnmB6ZZ10vl7T0Y+qrHQF2cmNTaFDuPshpKqpUMGMc9jgj9MLDZ9"
></script>
<script>
const EXAMPLE_QUERY = `# Welcome to GraphiQL 🍓
Expand Down
9 changes: 7 additions & 2 deletions strawberry/types/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
from .base import get_object_definition, has_object_definition
from .execution import ExecutionContext, ExecutionResult, SubscriptionExecutionResult
from .execution import (
ExecutionContext,
ExecutionResult,
InitialIncrementalExecutionResult,
SubscriptionExecutionResult,
)
from .info import Info

__all__ = [
"ExecutionContext",
"ExecutionResult",
"Info",
"Info",
"InitialIncrementalExecutionResult",
"SubscriptionExecutionResult",
"get_object_definition",
"has_object_definition",
Expand Down
9 changes: 9 additions & 0 deletions strawberry/types/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,15 @@ class ExecutionResult:
extensions: Optional[dict[str, Any]] = None


@dataclasses.dataclass
class InitialIncrementalExecutionResult:
data: Optional[dict[str, Any]]
errors: Optional[list[GraphQLError]]
pending: list[Any]
has_next: bool
extensions: Optional[dict[str, Any]] = None


@dataclasses.dataclass
class PreExecutionError(ExecutionResult):
"""Differentiate between a normal execution result and an immediate error.
Expand Down
Empty file.
44 changes: 44 additions & 0 deletions tests/http/incremental/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import contextlib

import pytest

from tests.http.clients.base import HttpClient


@pytest.fixture
def http_client(http_client_class: type[HttpClient]) -> HttpClient:
with contextlib.suppress(ImportError):
import django

if django.VERSION < (4, 2):
pytest.skip(reason="Django < 4.2 doesn't async streaming responses")

from tests.http.clients.django import DjangoHttpClient

if http_client_class is DjangoHttpClient:
pytest.skip(reason="(sync) DjangoHttpClient doesn't support streaming")

with contextlib.suppress(ImportError):
from tests.http.clients.channels import SyncChannelsHttpClient

# TODO: why do we have a sync channels client?
if http_client_class is SyncChannelsHttpClient:
pytest.skip(reason="SyncChannelsHttpClient doesn't support streaming")

with contextlib.suppress(ImportError):
from tests.http.clients.async_flask import AsyncFlaskHttpClient
from tests.http.clients.flask import FlaskHttpClient

if http_client_class is FlaskHttpClient:
pytest.skip(reason="FlaskHttpClient doesn't support streaming")

if http_client_class is AsyncFlaskHttpClient:
pytest.xfail(reason="AsyncFlaskHttpClient doesn't support streaming")

with contextlib.suppress(ImportError):
from tests.http.clients.chalice import ChaliceHttpClient

if http_client_class is ChaliceHttpClient:
pytest.skip(reason="ChaliceHttpClient doesn't support streaming")

return http_client_class()
Loading
Loading